40 Commits

Author SHA1 Message Date
e8845ce28d incremental reward 2026-03-30 14:15:56 -07:00
73dcb93d5e more steps 2026-03-30 14:08:31 -07:00
3b00541644 add reward 2026-03-30 14:08:31 -07:00
b29d5f3845 normal torch ok 2026-03-30 13:15:25 -07:00
48cd49bd43 document setup using uv + fish 2026-03-30 13:13:19 -07:00
8ac74d36f0 CompleteMap is only needed for construction 2026-03-30 12:46:28 -07:00
9995e852d4 update ndarray 2026-03-30 12:40:16 -07:00
Cameron Grant
0a1c8068fe Add rustfmt.toml and improve code formatting in lib.rs. 2026-03-30 12:39:55 -07:00
792078121b CompleteMap is only needed for construction 2026-03-30 12:26:34 -07:00
df3e813dd9 fix float to bool 2026-03-30 12:25:37 -07:00
cea0bcbaf3 Use slice deconstruction 2026-03-30 12:18:42 -07:00
e79d0378ac Use helper function 2026-03-30 12:16:57 -07:00
Cameron Grant
d907672daa Add poetry.lock file for Python dependency management. 2026-03-30 12:10:13 -07:00
Cameron Grant
899278ff64 Add Cargo.lock file for Rust dependency management. 2026-03-30 11:53:59 -07:00
Cameron Grant
5da88a0f69 Converted full project to PyTorch. 2026-03-30 11:39:04 -07:00
96c21fffa9 separate depth from inputs 2026-03-28 08:44:22 -07:00
357e0f4a20 print best loss 2026-03-28 08:36:16 -07:00
31bfa208f8 include angles in history 2026-03-28 08:26:03 -07:00
1d09378bfd silence lint 2026-03-28 08:08:01 -07:00
bf2bf6d693 dropout first 2026-03-28 07:37:04 -07:00
a144ff1178 fix file name shenanigans 2026-03-27 19:33:19 -07:00
48f9657d0f don't hardcode map and bot 2026-03-27 16:46:02 -07:00
e38c0a92b4 remove chrono dep 2026-03-27 16:28:30 -07:00
148471dce1 simulate: add output_file argument 2026-03-27 16:28:30 -07:00
7bf439395b write model name based on num params 2026-03-27 16:16:26 -07:00
03f5eb5c13 tweak model 2026-03-27 16:04:03 -07:00
1e7bb6c4ce format 2026-03-27 15:57:32 -07:00
d8b0f9abbb rename GraphicsState to InputGenerator 2026-03-27 15:57:17 -07:00
fb8c6e2492 training options 2026-03-27 15:56:15 -07:00
18cad85b62 add cli args 2026-03-27 15:56:15 -07:00
b195a7eb95 split code into modules 2026-03-27 15:46:51 -07:00
4208090da0 add clap dep 2026-03-27 15:32:17 -07:00
e31b148f41 simulator 2026-03-27 15:29:32 -07:00
a05113baa5 feed position history into model inputs 2026-03-27 15:28:53 -07:00
9ad8a70ad0 hardcode depth "normalization" 2026-03-27 15:13:29 -07:00
e890623f2e add dropout to input 2026-03-27 15:00:30 -07:00
7d55e872e7 save model with current date 2026-03-27 11:56:45 -07:00
1e1cbeb180 graphics state 2026-03-27 11:56:45 -07:00
e19c46d851 save best model 2026-03-27 11:25:34 -07:00
59bb8eee12 implement training 2026-03-27 11:13:23 -07:00
13 changed files with 3261 additions and 6714 deletions

27
.gitignore vendored

@@ -1,2 +1,27 @@
/files
# Rust
/target
# Python
__pycache__
*.pyc
*.egg-info
.eggs
dist
build
.venv
# Data files
/files
*.snfm
*.qbot
*.snfb
*.model
*.bin
# TensorBoard
runs/
# IDE / tools
.claude
.idea
_rust.*

6508
Cargo.lock generated

File diff suppressed because it is too large

Cargo.toml

@@ -1,10 +1,17 @@
[package]
name = "strafe-ai"
name = "strafe-ai-py"
version = "0.1.0"
edition = "2024"
[lib]
name = "_rust"
crate-type = ["cdylib"]
[dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
pyo3 = { version = "0.28", features = ["extension-module"] }
numpy = "0.28"
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0"
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
@@ -13,4 +20,3 @@ strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
pollster = "0.4.0"

77
README.md Normal file

@@ -0,0 +1,77 @@
# strafe-ai-py
PyTorch + Rust environment for training an AI to play bhop maps.
## Architecture
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
The Rust core is wrapped in a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
## Setup
Requires:
- Python 3.10+
- Rust toolchain
- CUDA toolkit
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
```bash
# create a virtual environment
uv venv
source .venv/bin/activate.fish # or .venv\Scripts\activate on Windows
# install Python dependencies and build tooling
uv pip install torch
uv pip install maturin gymnasium numpy tensorboard
# build and install the Rust extension + Python deps
maturin develop --release
# or install with RL extras
pip install -e ".[rl]"
```
## Usage
```bash
# test the environment
python -m strafe_ai.train --map-file path/to/map.snfm --bot-file path/to/replay.qbot
# tensorboard
tensorboard --logdir runs
```
## Project Structure
```
strafe_ai/ Python package
environment.py Gymnasium env wrapping Rust sim
model.py PyTorch model (StrafeNet)
train.py Training script
src/
lib.rs PyO3 bindings (physics + rendering)
```
## Environment API
```python
from strafe_ai import StrafeEnvironment
env = StrafeEnvironment("map.snfm", "replay.qbot")
obs, info = env.reset()
# obs["position_history"] — (50,) float32 — 10 recent positions + angles
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
action = env.action_space.sample() # 7 floats
obs, reward, terminated, truncated, info = env.step(action)
```
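The five binary controls are thresholded on the Rust side (`float_to_bool`, cutoff 0.5). A pure-Python sketch of how a sampled action decomposes; the helper name `split_action` is illustrative, not part of the package:

```python
def split_action(action):
    """Mirror float_to_bool in src/lib.rs: values above 0.5 count as pressed."""
    controls = [x > 0.5 for x in action[:5]]  # forward, left, back, right, jump
    mouse_dx, mouse_dy = action[5], action[6]
    return controls, (mouse_dx, mouse_dy)

# exactly 0.5 is NOT pressed (Rust uses `0.5 < float`)
controls, mouse = split_action([0.9, 0.1, 0.4, 0.6, 0.5, 3.0, -2.0])
# controls -> [True, False, False, True, False], mouse -> (3.0, -2.0)
```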
## Next Steps
- [ ] Implement reward function (curve_dt along WR path)
- [ ] Train with PPO (Stable Baselines3)
- [ ] Add .qbot replay loading for imitation learning pretraining
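A skeleton of the first item already exists in `src/lib.rs`: `reward()` measures progress as the playback time of the closest point on the reference (WR) run and pays out only new best progress. A standalone sketch of that logic, with `progress` assumed to be precomputed by the Rust side:

```python
class IncrementalProgressReward:
    """Pay out only improvements over the best progress seen this episode."""

    def __init__(self):
        self.best_progress = 0.0

    def __call__(self, progress: float) -> float:
        # progress: playback time of the closest reference-path point
        if progress > self.best_progress:
            diff = progress - self.best_progress
            self.best_progress = progress
            return diff
        return 0.0
```

Summed over an episode this telescopes to the furthest point reached, so backtracking is never rewarded twice.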

2319
poetry.lock generated Normal file

File diff suppressed because it is too large

22
pyproject.toml Normal file

@@ -0,0 +1,22 @@
[build-system]
requires = ["maturin>=1.5,<2.0"]
build-backend = "maturin"
[project]
name = "strafe-ai"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"gymnasium",
"numpy",
"tensorboard",
]
[project.optional-dependencies]
rl = ["stable-baselines3"]
[tool.maturin]
python-source = "."
features = ["pyo3/extension-module"]
module-name = "strafe_ai._rust"

rustfmt.toml

@@ -1 +1 @@
hard_tabs = true
hard_tabs = true

414
src/lib.rs Normal file

@@ -0,0 +1,414 @@
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
const SIZE_X: u32 = 64;
const SIZE_Y: u32 = 36;
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
const POSITION_HISTORY: usize = 10;
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
fn float_to_bool(float: f32) -> bool {
0.5 < float
}
/// The Rust-side environment that wraps physics simulation and depth rendering.
/// Exposed to Python via PyO3.
#[pyclass]
struct StrafeEnv {
// rendering
wgpu_device: wgpu::Device,
wgpu_queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
// simulation
geometry: PhysicsData,
simulation: PhysicsState,
time: PhysicsTime,
mouse_pos: glam::IVec2,
start_offset: glam::Vec3,
// scoring
bot: strafesnet_roblox_bot_player::bot::CompleteBot,
bvh: strafesnet_roblox_bot_player::bvh::Bvh,
best_progress: f32,
// position history (relative positions + angles)
position_history: Vec<(glam::Vec3, glam::Vec2)>,
}
#[pymethods]
impl StrafeEnv {
/// Create a new environment from a map file path.
#[new]
fn new(map_path: &str, bot_path: &str) -> PyResult<Self> {
let map_file = std::fs::read(map_path)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
.into_complete_map()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
let bot_file = std::fs::read(bot_path)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read bot: {e}")))?;
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file))
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode bot: {e:?}")))?;
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to init bot: {e:?}")))?;
let bvh = strafesnet_roblox_bot_player::bvh::Bvh::new(&bot);
// compute start offset
let modes = map.modes.clone().denormalize();
let mode = modes
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
.ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
let start_zone = map
.models
.get(mode.get_start().get() as usize)
.ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
let start_offset = glam::Vec3::from_array(
start_zone
.transform
.translation
.map(|f| f.into())
.to_array(),
);
// init wgpu
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (wgpu_device, wgpu_queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&wgpu_device,
&wgpu_queue,
SIZE,
FORMAT,
LIMITS,
);
graphics
.change_map(&wgpu_device, &wgpu_queue, &map)
.unwrap();
let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE_X,
height: SIZE_Y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let geometry = PhysicsData::new(&map);
let simulation = PhysicsState::default();
let mut env = Self {
wgpu_device,
wgpu_queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
geometry,
simulation,
time: PhysicsTime::ZERO,
mouse_pos: glam::ivec2(-5300, 0),
start_offset,
position_history: Vec::with_capacity(POSITION_HISTORY),
bot,
bvh,
best_progress: 0.0,
};
// initial reset
env.do_reset();
Ok(env)
}
/// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
self.do_reset();
self.get_observation()
}
/// Take one step with the given action.
/// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
/// Returns (position_history, depth, done)
fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
let &[
move_forward,
move_left,
move_back,
move_right,
jump,
mouse_dx,
mouse_dy,
] = action.as_slice()
else {
return Err(PyRuntimeError::new_err("Action must have 7 elements"));
};
// apply controls
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveForward(float_to_bool(move_forward)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveLeft(float_to_bool(move_left)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveBack(float_to_bool(move_back)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveRight(float_to_bool(move_right)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetJump(float_to_bool(jump)),
));
// apply mouse
self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
let next_time = self.time + STEP;
self.run_instruction(PhysicsInputInstruction::Mouse(
MouseInstruction::SetNextMouse(MouseState {
pos: self.mouse_pos,
time: next_time,
}),
));
self.time = next_time;
let (pos_hist, depth) = self.get_observation();
// done after 20 seconds of simulation
let done = self.time >= PhysicsTime::from_millis(20_000);
Ok((pos_hist, depth, done))
}
fn reward(&mut self) -> f32 {
// this is going int -> float -> int but whatever.
let point = self
.simulation
.body()
.position
.map(Into::<f32>::into)
.to_array()
.into();
let time = self.bvh.closest_time_to_point(&self.bot, point).unwrap();
let progress = self.bot.playback_time(time).into();
if self.best_progress < progress {
let diff = progress - self.best_progress;
self.best_progress = progress;
return diff;
} else {
return 0.0;
}
}
/// Get the current position as [x, y, z].
fn get_position(&self) -> Vec<f32> {
let trajectory = self.simulation.camera_trajectory(&self.geometry);
let pos = trajectory
.extrapolated_position(self.time)
.map(Into::<f32>::into)
.to_array();
pos.to_vec()
}
/// Get observation dimensions.
#[staticmethod]
fn observation_sizes() -> (usize, usize) {
(POSITION_HISTORY * 5, DEPTH_PIXELS)
}
/// Get action size.
#[staticmethod]
fn action_size() -> usize {
7
}
}
impl StrafeEnv {
fn do_reset(&mut self) {
self.simulation = PhysicsState::default();
self.time = PhysicsTime::ZERO;
self.mouse_pos = glam::ivec2(-5300, 0);
self.position_history.clear();
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
strafesnet_common::gameplay_modes::ModeId::MAIN,
)));
}
fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
PhysicsContext::run_input_instruction(
&mut self.simulation,
&self.geometry,
TimedInstruction {
time: self.time,
instruction,
},
);
}
fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
let trajectory = self.simulation.camera_trajectory(&self.geometry);
let pos: glam::Vec3 = trajectory
.extrapolated_position(self.time)
.map(Into::<f32>::into)
.to_array()
.into();
let camera = self.simulation.camera();
let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
// build position history input
let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * 5);
if !self.position_history.is_empty() {
let cam_matrix =
strafesnet_graphics::graphics::view_inv(pos - self.start_offset, angles).inverse();
for &(p, a) in self.position_history.iter().rev() {
let relative_pos = cam_matrix.transform_vector3(p);
let relative_ang = glam::vec2(angles.x - a.x, a.y);
pos_hist.extend_from_slice(&relative_pos.to_array());
pos_hist.extend_from_slice(&relative_ang.to_array());
}
}
// pad remaining history with zeros
for _ in self.position_history.len()..POSITION_HISTORY {
pos_hist.extend_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0]);
}
// update position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history.push((pos, angles));
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = (pos, angles);
}
// render depth
let render_pos = pos - self.start_offset;
let depth = self.render_depth(render_pos, angles);
(pos_hist, depth)
}
fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
let mut encoder =
self.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("depth encoder"),
});
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE_Y),
},
},
wgpu::Extent3d {
width: SIZE_X,
height: SIZE_Y,
depth_or_array_layers: 1,
},
);
self.wgpu_queue.submit([encoder.finish()]);
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.wgpu_device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
let mut depth = Vec::with_capacity(DEPTH_PIXELS);
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
for y in 0..SIZE_Y {
depth.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
);
}
self.texture_data.clear();
depth
}
}
/// Python module definition
#[pymodule]
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<StrafeEnv>()?;
Ok(())
}
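`render_depth` above copies the depth texture through a staging buffer whose rows are padded to wgpu's 256-byte `bytes_per_row` alignment, then strips the padding and remaps raw depth from [0, 1] to [1, -1]. A pure-Python sketch of that unpacking, assuming the 64x36 layout above (64 px * 4 bytes is already 256, so the padding is degenerate here, but the stride computation handles wider rows too):

```python
import struct

SIZE_X, SIZE_Y = 64, 36
ROW_BYTES = SIZE_X * 4               # packed f32 row
STRIDE = -(-ROW_BYTES // 256) * 256  # next_multiple_of(256), as in src/lib.rs

def unpack_depth(buf: bytes) -> list[float]:
    """Strip per-row padding and remap depth d -> 1 - 2*d."""
    out = []
    for y in range(SIZE_Y):
        row = buf[y * STRIDE : y * STRIDE + ROW_BYTES]
        out.extend(1.0 - 2.0 * v for (v,) in struct.iter_unpack("<f", row))
    return out
```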


@@ -1,363 +0,0 @@
use burn::backend::Autodiff;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu};
use burn::optim::{GradientsParams, Optimizer, AdamConfig};
use burn::prelude::*;
type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_graphics::setup;
use strafesnet_roblox_bot_file::v0;
const SIZE_X: usize = 64;
const SIZE_Y: usize = 36;
const INPUT: usize = SIZE_X * SIZE_Y;
const HIDDEN: [usize; 2] = [
INPUT >> 3,
INPUT >> 7,
];
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
const OUTPUT: usize = 7;
#[derive(Module, Debug)]
struct Net<B: Backend> {
input: Linear<B>,
hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>,
activation: Relu,
}
impl<B: Backend> Net<B> {
fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
Self {
input,
hidden,
output,
activation: Relu::new(),
}
}
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.input.forward(input);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}
fn training() {
let gpu_id: usize = std::env::args()
.skip(1)
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
// load map
// load replay
// setup player
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
// read files
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// setup graphics
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
let size = [SIZE_X as u32, SIZE_Y as u32].into();
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device, &queue, size, FORMAT, LIMITS,
);
graphics.change_map(&device, &queue, &map).unwrap();
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
// set up textures
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: size.x,
height: size.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING),
..Default::default()
});
// bytes_per_row needs to be a multiple of 256.
let stride_size = (size.x * size_of::<f32>() as u32).next_multiple_of(256);
let mut texture_data = Vec::<u8>::with_capacity((stride_size * size.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = (size.x * size.y) as usize * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
fn p(v: v0::Vector3) -> [f32; 3] {
[v.x, v.y, v.z]
}
fn a(a: v0::Vector3) -> [f32; 2] {
[a.y, a.x]
}
fn sub<T: core::ops::Sub>(lhs: T, rhs: T) -> T::Output {
lhs - rhs
}
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
graphics.encode_commands(
&mut encoder,
&graphics_texture_view,
sub(p(output_event.event.position).into(), world_offset),
a(output_event.event.angles).into(),
);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(stride_size as u32),
rows_per_image: Some(size.y),
},
},
wgpu::Extent3d {
width: size.x,
height: size.y,
depth_or_array_layers: 1,
},
);
queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
texture_data.extend_from_slice(&view[..]);
}
output_staging_buffer.unmap();
let inputs_start = inputs.len();
// discombolulate stride
for y in 0..size.y {
inputs.extend(
texture_data[(stride_size * y) as usize
..(stride_size * y + size.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap())),
)
}
let inputs_end = inputs.len();
println!("inputs = {:?}", &inputs[inputs_start..inputs_end]);
texture_data.clear();
}
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
println!("Training model ({} parameters)", model.num_params());
let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_data(
TensorData::new(
inputs,
Shape::new([training_samples, (size.x * size.y) as usize]),
),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
const LEARNING_RATE: f64 = 0.5;
const EPOCHS: usize = 10000;
for epoch in 0..EPOCHS {
let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
let loss_scalar = loss.clone().into_scalar();
if epoch == 0 {
// kinda a fake print, but that's what is happening after this point
println!("Compiling optimized GPU kernels...");
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(LEARNING_RATE, model, grads);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}
}
}
fn inference() {
// load map
// setup simulation
// setup agent-simulation feedback loop
// go!
}
fn main() {
training();
}

4
strafe_ai/__init__.py Normal file

@@ -0,0 +1,4 @@
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeNet
__all__ = ["StrafeEnvironment", "StrafeNet"]

70
strafe_ai/environment.py Normal file

@@ -0,0 +1,70 @@
"""
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
Usage:
env = StrafeEnvironment("path/to/map.snfm")
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
"""
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from strafe_ai._rust import StrafeEnv
class StrafeEnvironment(gym.Env):
"""
A bhop environment.
Observation: dict with "position_history" and "depth" arrays.
Action: 7 floats — [forward, left, back, right, jump, mouse_dx, mouse_dy]
First 5 are binary (thresholded at 0.5), last 2 are continuous.
"""
metadata = {"render_modes": ["none"]}
def __init__(self, map_path: str, bot_path: str):
super().__init__()
self._env = StrafeEnv(map_path, bot_path)
pos_size, depth_size = StrafeEnv.observation_sizes()
# observation: position history + depth buffer
self.observation_space = spaces.Dict({
"position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
"depth": spaces.Box(-1.0, 1.0, shape=(depth_size,), dtype=np.float32),
})
# action: 5 binary controls + 2 continuous mouse deltas
self.action_space = spaces.Box(
low=np.array([0, 0, 0, 0, 0, -100, -100], dtype=np.float32),
high=np.array([1, 1, 1, 1, 1, 100, 100], dtype=np.float32),
)
def reset(self, seed=None, options=None):
super().reset(seed=seed)
pos_hist, depth = self._env.reset()
obs = {
"position_history": np.array(pos_hist, dtype=np.float32),
"depth": np.array(depth, dtype=np.float32),
}
return obs, {}
def step(self, action):
action_list = action.tolist() if hasattr(action, "tolist") else list(action)
pos_hist, depth, done = self._env.step(action_list)
obs = {
"position_history": np.array(pos_hist, dtype=np.float32),
"depth": np.array(depth, dtype=np.float32),
}
reward = self._env.reward()
return obs, reward, done, False, {}
def get_position(self):
"""Get the agent's current [x, y, z] position."""
return np.array(self._env.get_position(), dtype=np.float32)
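The `position_history` half of the observation is a fixed-size window: the Rust side emits the last 10 entries newest-first, 5 floats each, zero-padded until the history fills. A simplified Python sketch of that layout (the real code first transforms each entry into the current camera frame; here entries are assumed already relative, and `encode_history` is illustrative):

```python
HISTORY = 10  # POSITION_HISTORY in src/lib.rs
FEATS = 5     # relative x, y, z + two angle components

def encode_history(entries):
    """Flatten newest-first and zero-pad to HISTORY * FEATS floats."""
    flat = []
    for entry in reversed(entries[-HISTORY:]):
        flat.extend(entry)  # each entry: FEATS floats
    flat.extend([0.0] * (FEATS * (HISTORY - min(len(entries), HISTORY))))
    return flat
```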

57
strafe_ai/model.py Normal file

@@ -0,0 +1,57 @@
"""
Neural network model — PyTorch equivalent of the Rust Net struct.
Takes position history + depth buffer, outputs 7 control values.
"""
import torch
import torch.nn as nn
from strafe_ai._rust import StrafeEnv
POS_SIZE, DEPTH_SIZE = StrafeEnv.observation_sizes()
ACTION_SIZE = StrafeEnv.action_size()
# hidden layer sizes (same ratios as the Rust version)
INPUT_SIZE = POS_SIZE + DEPTH_SIZE
HIDDEN = [INPUT_SIZE >> 3, INPUT_SIZE >> 5, INPUT_SIZE >> 7]
class StrafeNet(nn.Module):
"""
Simple feedforward network for bhop control.
Architecture matches the Rust version:
- Dropout on depth input
- Concatenate position history + depth
- 3 hidden layers with ReLU
- Linear output (7 values)
"""
def __init__(self, dropout_rate: float = 0.1):
super().__init__()
self.depth_dropout = nn.Dropout(dropout_rate)
layers = []
prev_size = INPUT_SIZE
for hidden_size in HIDDEN:
layers.append(nn.Linear(prev_size, hidden_size))
layers.append(nn.ReLU())
prev_size = hidden_size
layers.append(nn.Linear(prev_size, ACTION_SIZE))
self.network = nn.Sequential(*layers)
def forward(self, position_history: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
"""
Args:
position_history: (batch, POS_SIZE) — relative position + angle history
depth: (batch, DEPTH_SIZE) — depth buffer pixels
Returns:
(batch, 7) — [forward, left, back, right, jump, mouse_dx, mouse_dy]
"""
x = self.depth_dropout(depth)
x = torch.cat([position_history, x], dim=1)
return self.network(x)
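The hidden sizes above are derived by bit-shifting the input width, so the concrete layer widths follow from `observation_sizes() = (50, 2304)`. A quick sketch that computes them, plus the resulting parameter count, without importing torch (the `param_count` helper is illustrative):

```python
POS_SIZE, DEPTH_SIZE = 50, 2304     # StrafeEnv.observation_sizes()
INPUT_SIZE = POS_SIZE + DEPTH_SIZE  # 2354
HIDDEN = [INPUT_SIZE >> 3, INPUT_SIZE >> 5, INPUT_SIZE >> 7]
ACTION_SIZE = 7

def param_count(sizes):
    """Weights plus biases of a plain fully connected stack."""
    return sum(a * b + b for a, b in zip(sizes, sizes[1:]))

sizes = [INPUT_SIZE] + HIDDEN + [ACTION_SIZE]
# HIDDEN == [294, 73, 18]; param_count(sizes) == 715_370
```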

100
strafe_ai/train.py Normal file

@@ -0,0 +1,100 @@
"""
Training script — imitation learning from .qbot replay files.
This is the Python equivalent of the Rust training code.
It trains the model to predict player controls from depth frames.
Usage:
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot --epochs 10000
"""
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from strafe_ai.model import StrafeNet, POS_SIZE, DEPTH_SIZE, ACTION_SIZE
from strafe_ai.environment import StrafeEnvironment
def load_training_data(env: StrafeEnvironment, bot_file: str):
"""
Generate training data by replaying a .qbot file through the environment.
For now this is a placeholder — the actual .qbot parsing happens in Rust.
You would call the Rust training data generator and load the results.
"""
# TODO: add a Rust function to generate training pairs from .qbot files
# For now, this returns dummy data to verify the pipeline works
raise NotImplementedError(
"Training data generation from .qbot files requires Rust bindings. "
"Use the Rust training code for imitation learning, or implement RL below."
)
def train_rl(map_file: str, bot_file: str, epochs: int, lr: float, device: str):
"""
Reinforcement learning training loop.
This is the main training path going forward.
"""
print(f"Setting up environment with map: {map_file}")
print(f"Setting up environment with bot: {bot_file}")
env = StrafeEnvironment(map_file, bot_file)
print(f"Using device: {device}")
torch_device = torch.device(device)
model = StrafeNet().to(torch_device)
num_params = sum(p.numel() for p in model.parameters())
print(f"Model has {num_params:,} parameters")
writer = SummaryWriter("runs/strafe-ai")
# TODO: implement RL algorithm (PPO via Stable Baselines3, or custom)
# For now, run random actions to verify the environment works
print("Running environment test (random actions)...")
obs, info = env.reset()
total_reward = 0.0
for step in range(10000):
# random action
action = env.action_space.sample()
# forward pass through model (just to verify it works)
pos_hist = torch.tensor(obs["position_history"], device=torch_device).unsqueeze(0)
depth = torch.tensor(obs["depth"], device=torch_device).unsqueeze(0)
with torch.no_grad():
predicted = model(pos_hist, depth)
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
if terminated or truncated:
obs, info = env.reset()
if step % 1000 == 0:
pos = env.get_position()
print(f" step {step:4d} | pos = ({pos[0]:.1f}, {pos[1]:.1f}, {pos[2]:.1f})")
print(f"Test complete. Total reward: {total_reward}")
writer.close()
def main():
parser = argparse.ArgumentParser(description="Train strafe-ai")
parser.add_argument("--map-file", required=True, help="Path to .snfm map file")
parser.add_argument("--bot-file", help="Path to .qbot file (for imitation learning)")
parser.add_argument("--epochs", type=int, default=10000)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_args()
train_rl(args.map_file, args.bot_file, args.epochs, args.lr, args.device)
if __name__ == "__main__":
main()