28 Commits

Author SHA1 Message Date
Cameron Grant
d907672daa Add poetry.lock file for Python dependency management. 2026-03-30 12:10:13 -07:00
Cameron Grant
899278ff64 Add Cargo.lock file for Rust dependency management. 2026-03-30 11:53:59 -07:00
Cameron Grant
5da88a0f69 Converted full project to PyTorch. 2026-03-30 11:39:04 -07:00
96c21fffa9 separate depth from inputs 2026-03-28 08:44:22 -07:00
357e0f4a20 print best loss 2026-03-28 08:36:16 -07:00
31bfa208f8 include angles in history 2026-03-28 08:26:03 -07:00
1d09378bfd silence lint 2026-03-28 08:08:01 -07:00
bf2bf6d693 dropout first 2026-03-28 07:37:04 -07:00
a144ff1178 fix file name shenanigans 2026-03-27 19:33:19 -07:00
48f9657d0f don't hardcode map and bot 2026-03-27 16:46:02 -07:00
e38c0a92b4 remove chrono dep 2026-03-27 16:28:30 -07:00
148471dce1 simulate: add output_file argument 2026-03-27 16:28:30 -07:00
7bf439395b write model name based on num params 2026-03-27 16:16:26 -07:00
03f5eb5c13 tweak model 2026-03-27 16:04:03 -07:00
1e7bb6c4ce format 2026-03-27 15:57:32 -07:00
d8b0f9abbb rename GraphicsState to InputGenerator 2026-03-27 15:57:17 -07:00
fb8c6e2492 training options 2026-03-27 15:56:15 -07:00
18cad85b62 add cli args 2026-03-27 15:56:15 -07:00
b195a7eb95 split code into modules 2026-03-27 15:46:51 -07:00
4208090da0 add clap dep 2026-03-27 15:32:17 -07:00
e31b148f41 simulator 2026-03-27 15:29:32 -07:00
a05113baa5 feed position history into model inputs 2026-03-27 15:28:53 -07:00
9ad8a70ad0 hardcode depth "normalization" 2026-03-27 15:13:29 -07:00
e890623f2e add dropout to input 2026-03-27 15:00:30 -07:00
7d55e872e7 save model with current date 2026-03-27 11:56:45 -07:00
1e1cbeb180 graphics state 2026-03-27 11:56:45 -07:00
e19c46d851 save best model 2026-03-27 11:25:34 -07:00
59bb8eee12 implement training 2026-03-27 11:13:23 -07:00
13 changed files with 3226 additions and 6503 deletions

26
.gitignore vendored
View File

@@ -1 +1,27 @@
# Rust
/target
# Python
__pycache__
*.pyc
*.egg-info
.eggs
dist
build
.venv
# Data files
/files
*.snfm
*.qbot
*.snfb
*.model
*.bin
# TensorBoard
runs/
# IDE / tools
.claude
.idea
_rust.*

6516
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,22 @@
[package]
name = "strafe-ai"
name = "strafe-ai-py"
version = "0.1.0"
edition = "2024"
[lib]
name = "_rust"
crate-type = ["cdylib"]
[dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
pyo3 = { version = "0.28", features = ["extension-module"] }
numpy = "0.28"
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0"
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" }
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
pollster = "0.4.0"

76
README.md Normal file
View File

@@ -0,0 +1,76 @@
# strafe-ai-py
PyTorch + Rust environment for training an AI to play bhop maps.
## Architecture
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
The Rust side exposes a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
## Setup
Requires:
- Python 3.10+
- Rust toolchain
- CUDA toolkit
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
```bash
# create a virtual environment
python -m venv .venv
source .venv/bin/activate # or .venv\Scripts\activate on Windows
# install maturin
pip install maturin
# build and install the Rust extension + Python deps
maturin develop --release
# or install with RL extras
pip install -e ".[rl]"
```
## Usage
```bash
# test the environment
python -m strafe_ai.train --map-file path/to/map.snfm
# tensorboard
tensorboard --logdir runs
```
## Project Structure
```
strafe_ai/ Python package
environment.py Gymnasium env wrapping Rust sim
model.py PyTorch model (StrafeNet)
train.py Training script
src/
lib.rs PyO3 bindings (physics + rendering)
```
## Environment API
```python
from strafe_ai import StrafeEnvironment
env = StrafeEnvironment("map.snfm")
obs, info = env.reset()
# obs["position_history"] — (50,) float32 — 10 recent positions + angles
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
action = env.action_space.sample() # 7 floats
obs, reward, terminated, truncated, info = env.step(action)
```
## Next Steps
- [ ] Implement reward function (curve_dt along WR path)
- [ ] Train with PPO (Stable Baselines3)
- [ ] Add .qbot replay loading for imitation learning pretraining

2319
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

22
pyproject.toml Normal file
View File

@@ -0,0 +1,22 @@
[build-system]
requires = ["maturin>=1.5,<2.0"]
build-backend = "maturin"
[project]
name = "strafe-ai"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"gymnasium",
"numpy",
"tensorboard",
]
[project.optional-dependencies]
rl = ["stable-baselines3"]
[tool.maturin]
python-source = "."
features = ["pyo3/extension-module"]
module-name = "strafe_ai._rust"

View File

@@ -1 +0,0 @@
hard_tabs = true

375
src/lib.rs Normal file
View File

@@ -0,0 +1,375 @@
use pyo3::prelude::*;
use pyo3::exceptions::PyRuntimeError;
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
// Rendering configuration for the offscreen depth pass.
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
// Depth buffer resolution: 64x36 = 2304 floats per observation.
const SIZE_X: u32 = 64;
const SIZE_Y: u32 = 36;
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
// wgpu requires texture-to-buffer copies to use 256-byte-aligned row strides,
// so each row of f32 depth pixels is padded up to the next multiple of 256.
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
// Number of (position, angles) samples kept for the history observation.
const POSITION_HISTORY: usize = 10;
// One environment step advances physics by 10 ms.
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
/// The Rust-side environment that wraps physics simulation and depth rendering.
/// Exposed to Python via PyO3.
#[pyclass]
struct StrafeEnv {
    // rendering
    wgpu_device: wgpu::Device,
    wgpu_queue: wgpu::Queue,
    graphics: strafesnet_roblox_bot_player::graphics::Graphics,
    // offscreen RGB color target the scene is drawn into; only the depth
    // attachment is read back (see render_depth)
    graphics_texture_view: wgpu::TextureView,
    // CPU-mappable buffer the depth texture is copied into each frame
    output_staging_buffer: wgpu::Buffer,
    // scratch space for the row-padded bytes read back from the staging buffer
    texture_data: Vec<u8>,
    // simulation
    geometry: PhysicsData,
    simulation: PhysicsState,
    time: PhysicsTime,
    mouse_pos: glam::IVec2,
    // translation of the map's start zone; render positions are taken
    // relative to it (see get_observation / render_depth)
    start_offset: glam::Vec3,
    // position history (relative positions + angles)
    position_history: Vec<(glam::Vec3, glam::Vec2)>,
    // map data (kept for reset)
    map: strafesnet_common::map::CompleteMap,
}
#[pymethods]
impl StrafeEnv {
    /// Create a new environment from a map file path.
    ///
    /// Reads and decodes the `.snfm` map, computes the start-zone offset,
    /// initializes headless wgpu + the bot-player graphics pipeline, and
    /// performs an initial reset so the environment is immediately steppable.
    #[new]
    fn new(map_path: &str) -> PyResult<Self> {
        let map_file = std::fs::read(map_path)
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
        let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
            .into_complete_map()
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
        // compute start offset: translation of the MAIN mode's start zone model
        let modes = map.modes.clone().denormalize();
        let mode = modes
            .get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
            .ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
        let start_zone = map
            .models
            .get(mode.get_start().get() as usize)
            .ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
        let start_offset = glam::Vec3::from_array(
            start_zone.transform.translation.map(|f| f.into()).to_array(),
        );
        // init wgpu (headless: no display handle or surface needed)
        let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
        let instance = wgpu::Instance::new(desc);
        let (wgpu_device, wgpu_queue) = pollster::block_on(async {
            let adapter = instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    force_fallback_adapter: false,
                    compatible_surface: None,
                })
                .await
                .unwrap();
            strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
                .await
                .unwrap()
        });
        let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
            &wgpu_device, &wgpu_queue, SIZE, FORMAT, LIMITS,
        );
        graphics
            .change_map(&wgpu_device, &wgpu_queue, &map)
            .unwrap();
        // RGB color target required by the render pass; the depth attachment
        // owned by `graphics` is what actually gets read back
        let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
            label: Some("RGB texture"),
            format: FORMAT,
            size: wgpu::Extent3d {
                width: SIZE_X,
                height: SIZE_Y,
                depth_or_array_layers: 1,
            },
            mip_level_count: 1,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
            view_formats: &[],
        });
        let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
            label: Some("RGB texture view"),
            aspect: wgpu::TextureAspect::All,
            usage: Some(
                wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
            ),
            ..Default::default()
        });
        // staging buffer sized for 256-byte-aligned rows (STRIDE_SIZE * SIZE_Y bytes)
        let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
        let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Output staging buffer"),
            size: texture_data.capacity() as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let geometry = PhysicsData::new(&map);
        let simulation = PhysicsState::default();
        let mut env = Self {
            wgpu_device,
            wgpu_queue,
            graphics,
            graphics_texture_view,
            output_staging_buffer,
            texture_data,
            geometry,
            simulation,
            time: PhysicsTime::ZERO,
            // initial mouse position; presumably chosen so the camera starts
            // facing a useful direction — TODO confirm
            mouse_pos: glam::ivec2(-5300, 0),
            start_offset,
            position_history: Vec::with_capacity(POSITION_HISTORY),
            map,
        };
        // initial reset
        env.do_reset();
        Ok(env)
    }
    /// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
    fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
        self.do_reset();
        self.get_observation()
    }
    /// Take one step with the given action.
    /// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
    /// Returns (position_history, depth, done)
    fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
        if action.len() != 7 {
            return Err(PyRuntimeError::new_err("Action must have 7 elements"));
        }
        // first five entries are binary controls, thresholded at 0.5
        let move_forward = action[0] > 0.5;
        let move_left = action[1] > 0.5;
        let move_back = action[2] > 0.5;
        let move_right = action[3] > 0.5;
        let jump = action[4] > 0.5;
        // last two are continuous mouse deltas in counts
        let mouse_dx = action[5];
        let mouse_dy = action[6];
        // apply controls
        self.run_instruction(PhysicsInputInstruction::SetControl(
            SetControlInstruction::SetMoveForward(move_forward),
        ));
        self.run_instruction(PhysicsInputInstruction::SetControl(
            SetControlInstruction::SetMoveLeft(move_left),
        ));
        self.run_instruction(PhysicsInputInstruction::SetControl(
            SetControlInstruction::SetMoveBack(move_back),
        ));
        self.run_instruction(PhysicsInputInstruction::SetControl(
            SetControlInstruction::SetMoveRight(move_right),
        ));
        self.run_instruction(PhysicsInputInstruction::SetControl(
            SetControlInstruction::SetJump(jump),
        ));
        // apply mouse: accumulate the rounded delta, then schedule the new
        // mouse state for the end of this step (next_time)
        self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
        let next_time = self.time + STEP;
        PhysicsContext::run_input_instruction(
            &mut self.simulation,
            &self.geometry,
            TimedInstruction {
                time: self.time,
                instruction: PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(
                    MouseState {
                        pos: self.mouse_pos,
                        time: next_time,
                    },
                )),
            },
        );
        self.time = next_time;
        let (pos_hist, depth) = self.get_observation();
        // done after 20 seconds of simulation
        let done = self.time >= PhysicsTime::from_millis(20_000);
        Ok((pos_hist, depth, done))
    }
    /// Get the current position as [x, y, z].
    fn get_position(&self) -> Vec<f32> {
        let trajectory = self.simulation.camera_trajectory(&self.geometry);
        let pos = trajectory
            .extrapolated_position(self.time)
            .map(Into::<f32>::into)
            .to_array();
        pos.to_vec()
    }
    /// Get observation dimensions: (position-history floats, depth pixels).
    #[staticmethod]
    fn observation_sizes() -> (usize, usize) {
        // 5 floats per history entry: relative position (3) + angles (2)
        (POSITION_HISTORY * 5, DEPTH_PIXELS)
    }
    /// Get action size.
    #[staticmethod]
    fn action_size() -> usize {
        7
    }
}
impl StrafeEnv {
    /// Reset simulation state, time, mouse position, and history, then replay
    /// the mode reset/restart instructions so the player is at the start zone.
    fn do_reset(&mut self) {
        self.simulation = PhysicsState::default();
        self.time = PhysicsTime::ZERO;
        self.mouse_pos = glam::ivec2(-5300, 0);
        self.position_history.clear();
        self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
        self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
            strafesnet_common::gameplay_modes::ModeId::MAIN,
        )));
    }
    /// Feed a single physics instruction to the simulation at the current time.
    fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
        PhysicsContext::run_input_instruction(
            &mut self.simulation,
            &self.geometry,
            TimedInstruction {
                time: self.time,
                instruction,
            },
        );
    }
    /// Build the (position_history, depth) observation pair, updating the
    /// rolling position history as a side effect.
    fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
        let trajectory = self.simulation.camera_trajectory(&self.geometry);
        let pos: glam::Vec3 = trajectory
            .extrapolated_position(self.time)
            .map(Into::<f32>::into)
            .to_array()
            .into();
        let camera = self.simulation.camera();
        let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
        // build position history input: most recent entries first
        let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * 5);
        if !self.position_history.is_empty() {
            let cam_matrix =
                strafesnet_graphics::graphics::view_inv(pos - self.start_offset, angles).inverse();
            for &(p, a) in self.position_history.iter().rev() {
                // NOTE(review): transform_vector3 applies rotation only, and `p`
                // is an absolute world position — confirm a displacement
                // (e.g. p - pos) wasn't intended here.
                let relative_pos = cam_matrix.transform_vector3(p);
                // NOTE(review): yaw (x) is made relative to the current camera
                // while pitch (y) is kept absolute — presumably intentional; confirm.
                let relative_ang = glam::vec2(angles.x - a.x, a.y);
                pos_hist.extend_from_slice(&relative_pos.to_array());
                pos_hist.extend_from_slice(&relative_ang.to_array());
            }
        }
        // pad remaining history with zeros (5 floats per missing entry)
        for _ in self.position_history.len()..POSITION_HISTORY {
            pos_hist.extend_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        }
        // update position history: push until full, then act as a ring buffer
        if self.position_history.len() < POSITION_HISTORY {
            self.position_history.push((pos, angles));
        } else {
            self.position_history.rotate_left(1);
            *self.position_history.last_mut().unwrap() = (pos, angles);
        }
        // render depth relative to the start zone
        let render_pos = pos - self.start_offset;
        let depth = self.render_depth(render_pos, angles);
        (pos_hist, depth)
    }
    /// Render one frame at the given camera position/angles and read back the
    /// depth attachment as DEPTH_PIXELS f32 values remapped to [-1, 1].
    fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
        let mut encoder = self
            .wgpu_device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("depth encoder"),
            });
        self.graphics
            .encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
        // copy the depth texture into the staging buffer with 256-byte-aligned rows
        encoder.copy_texture_to_buffer(
            wgpu::TexelCopyTextureInfo {
                texture: self.graphics.depth_texture(),
                mip_level: 0,
                origin: wgpu::Origin3d::ZERO,
                aspect: wgpu::TextureAspect::All,
            },
            wgpu::TexelCopyBufferInfo {
                buffer: &self.output_staging_buffer,
                layout: wgpu::TexelCopyBufferLayout {
                    offset: 0,
                    bytes_per_row: Some(STRIDE_SIZE),
                    rows_per_image: Some(SIZE_Y),
                },
            },
            wgpu::Extent3d {
                width: SIZE_X,
                height: SIZE_Y,
                depth_or_array_layers: 1,
            },
        );
        self.wgpu_queue.submit([encoder.finish()]);
        // block until the staging buffer is mapped and readable on the CPU
        let buffer_slice = self.output_staging_buffer.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
        self.wgpu_device
            .poll(wgpu::PollType::wait_indefinitely())
            .unwrap();
        receiver.recv().unwrap().unwrap();
        let mut depth = Vec::with_capacity(DEPTH_PIXELS);
        {
            let view = buffer_slice.get_mapped_range();
            self.texture_data.extend_from_slice(&view[..]);
        }
        self.output_staging_buffer.unmap();
        // strip row padding and decode each f32 pixel, remapping d -> 1 - 2d
        // (so raw depth 0 becomes 1.0 and raw depth 1 becomes -1.0)
        for y in 0..SIZE_Y {
            depth.extend(
                self.texture_data[(STRIDE_SIZE * y) as usize
                    ..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
                    .chunks_exact(4)
                    .map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
            );
        }
        self.texture_data.clear();
        depth
    }
}
/// Python module definition
#[pymodule]
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Expose the environment class as strafe_ai._rust.StrafeEnv
    m.add_class::<StrafeEnv>()?;
    Ok(())
}

View File

@@ -1,147 +0,0 @@
use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
use burn::prelude::*;
// burn backends: CUDA for inference, autodiff wrapper for training
type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_common::session::Time as SessionTime;
use strafesnet_graphics::setup;
// network layer widths
const INPUT: usize = 2;
const HIDDEN: usize = 64;
// output controls, one value each:
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
const OUTPUT: usize = 7;
/// Fully-connected network: input projection, hidden stack, sigmoid output.
#[derive(Module, Debug)]
struct Net<B: Backend> {
    input: Linear<B>,
    // NOTE(review): 64 hidden Linear layers of width HIDDEN — unusually deep
    // for a plain MLP; confirm this depth is intended.
    hidden: [Linear<B>; 64],
    output: Linear<B>,
    activation: Relu,
    sigmoid: Sigmoid,
}
impl<B: Backend> Net<B> {
    /// Initialize all layers on the given device.
    fn init(device: &B::Device) -> Self {
        Self {
            input: LinearConfig::new(INPUT, HIDDEN).init(device),
            hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)),
            output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
            activation: Relu::new(),
            sigmoid: Sigmoid::new(),
        }
    }
    /// Forward pass: input -> ReLU -> (hidden -> ReLU)* -> output -> sigmoid.
    /// Sigmoid squashes all 7 outputs to (0, 1).
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.input.forward(input);
        let mut x = self.activation.forward(x);
        for layer in &self.hidden {
            x = layer.forward(x);
            x = self.activation.forward(x);
        }
        let x = self.output.forward(x);
        self.sigmoid.forward(x)
    }
}
/// Prototype imitation-learning loop: loads a map + replay, sets up graphics,
/// then trains Net on a single dummy (zeros -> zeros) example with SGD.
fn training() {
    // load map
    // load replay
    // setup player
    const SIZE_X: usize = 64;
    const SIZE_Y: usize = 36;
    let map_file = include_bytes!("../bhop_marble_5692093612.snfm");
    let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
    // read files
    let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
        .unwrap()
        .into_complete_map()
        .unwrap();
    let timelines =
        strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
    let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
    // NOTE(review): playback_head and graphics are built but never used below —
    // presumably scaffolding for the data-generation step that never landed.
    let mut playback_head =
        strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO);
    // setup graphics
    let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
    let instance = wgpu::Instance::new(desc);
    let (device, queue) = pollster::block_on(async {
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .unwrap();
        setup::step4::request_device(&adapter, LIMITS)
            .await
            .unwrap()
    });
    const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
    let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
        &device,
        &queue,
        [SIZE_X as u32, SIZE_Y as u32].into(),
        FORMAT,
        LIMITS,
    );
    // setup simulation
    // run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
    // NOTE(review): this `device` (burn backend device) shadows the wgpu
    // `device` bound above — easy to misread.
    let device = Default::default();
    let mut model: Net<TrainingBackend> = Net::init(&device);
    let mut optim = SgdConfig::new().init();
    // dummy all-zero training pair to exercise the optimizer loop
    let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
    let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
    const LEARNING_RATE: f64 = 0.5;
    const EPOCHS: usize = 100;
    for epoch in 0..EPOCHS {
        let predictions = model.forward(inputs.clone());
        let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
        // log every EPOCHS/16 epochs plus the final epoch
        if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
            // .clone().into_scalar() extracts the f32 value from a 1-element tensor.
            println!(
                " epoch {:>5} | loss = {:.8}",
                epoch,
                loss.clone().into_scalar()
            );
        }
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &model);
        model = optim.step(LEARNING_RATE, model, grads);
    }
}
/// Planned inference entry point — not yet implemented.
fn inference() {
    // load map
    // setup simulation
    // setup agent-simulation feedback loop
    // go!
}
// Intentionally empty: training()/inference() are invoked manually during development.
fn main() {}

4
strafe_ai/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""strafe_ai package: Gymnasium environment + PyTorch model for bhop training."""
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeNet

__all__ = ["StrafeEnvironment", "StrafeNet"]

71
strafe_ai/environment.py Normal file
View File

@@ -0,0 +1,71 @@
"""
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
Usage:
env = StrafeEnvironment("path/to/map.snfm")
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
"""
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from strafe_ai._rust import StrafeEnv
class StrafeEnvironment(gym.Env):
    """
    A bhop environment backed by the Rust physics sim + depth renderer.

    Observation: dict with "position_history" and "depth" float32 arrays.
    Action: 7 floats — [forward, left, back, right, jump, mouse_dx, mouse_dy]
    First 5 are binary (thresholded at 0.5 on the Rust side), last 2 are
    continuous mouse deltas.
    """

    metadata = {"render_modes": ["none"]}

    def __init__(self, map_path: str):
        super().__init__()
        self._env = StrafeEnv(map_path)
        pos_size, depth_size = StrafeEnv.observation_sizes()
        # observation: position history + depth buffer
        self.observation_space = spaces.Dict({
            "position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
            "depth": spaces.Box(-1.0, 1.0, shape=(depth_size,), dtype=np.float32),
        })
        # action: 5 binary controls + 2 continuous mouse deltas
        self.action_space = spaces.Box(
            low=np.array([0, 0, 0, 0, 0, -100, -100], dtype=np.float32),
            high=np.array([1, 1, 1, 1, 1, 100, 100], dtype=np.float32),
        )

    @staticmethod
    def _make_obs(pos_hist, depth):
        """Package the raw Rust observation lists into the observation dict."""
        return {
            "position_history": np.array(pos_hist, dtype=np.float32),
            "depth": np.array(depth, dtype=np.float32),
        }

    def reset(self, seed=None, options=None):
        """Reset the sim to the map's start zone. Returns (obs, info)."""
        super().reset(seed=seed)
        pos_hist, depth = self._env.reset()
        return self._make_obs(pos_hist, depth), {}

    def step(self, action):
        """Advance one physics tick. Returns (obs, reward, terminated, truncated, info)."""
        action_list = action.tolist() if hasattr(action, "tolist") else list(action)
        pos_hist, depth, done = self._env.step(action_list)
        # TODO: implement Rhys's reward function (curve_dt progress along WR path)
        reward = 0.0
        # NOTE(review): `done` fires on the Rust-side 20 s time limit — arguably
        # this is truncation rather than termination; confirm before RL training.
        return self._make_obs(pos_hist, depth), reward, done, False, {}

    def get_position(self):
        """Get the agent's current [x, y, z] position."""
        return np.array(self._env.get_position(), dtype=np.float32)

57
strafe_ai/model.py Normal file
View File

@@ -0,0 +1,57 @@
"""
Neural network model — PyTorch equivalent of the Rust Net struct.
Takes position history + depth buffer, outputs 7 control values.
"""
import torch
import torch.nn as nn
from strafe_ai._rust import StrafeEnv
POS_SIZE, DEPTH_SIZE = StrafeEnv.observation_sizes()
ACTION_SIZE = StrafeEnv.action_size()
# hidden layer sizes (same ratios as the Rust version)
INPUT_SIZE = POS_SIZE + DEPTH_SIZE
HIDDEN = [INPUT_SIZE >> 3, INPUT_SIZE >> 5, INPUT_SIZE >> 7]
class StrafeNet(nn.Module):
    """
    Simple feedforward network for bhop control.

    Mirrors the Rust architecture:
      - Dropout applied to the depth input only
      - Position history concatenated with the (dropped-out) depth
      - Hidden Linear+ReLU stack sized by HIDDEN
      - Final Linear producing the 7 control values
    """

    def __init__(self, dropout_rate: float = 0.1):
        super().__init__()
        self.depth_dropout = nn.Dropout(dropout_rate)
        # widths of consecutive layers: input, then each hidden size
        widths = [INPUT_SIZE, *HIDDEN]
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(n_in, n_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], ACTION_SIZE))
        self.network = nn.Sequential(*stack)

    def forward(self, position_history: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        """
        Args:
            position_history: (batch, POS_SIZE) — relative position + angle history
            depth: (batch, DEPTH_SIZE) — depth buffer pixels
        Returns:
            (batch, 7) — [forward, left, back, right, jump, mouse_dx, mouse_dy]
        """
        dropped = self.depth_dropout(depth)
        combined = torch.cat([position_history, dropped], dim=1)
        return self.network(combined)

99
strafe_ai/train.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Training script — imitation learning from .qbot replay files.
This is the Python equivalent of the Rust training code.
It trains the model to predict player controls from depth frames.
Usage:
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot --epochs 10000
"""
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from strafe_ai.model import StrafeNet, POS_SIZE, DEPTH_SIZE, ACTION_SIZE
from strafe_ai.environment import StrafeEnvironment
def load_training_data(env: StrafeEnvironment, bot_file: str):
    """
    Generate training data by replaying a .qbot file through the environment.

    Placeholder: the actual .qbot parsing lives on the Rust side. Once a Rust
    training-data generator is exposed, this should call it and load the
    resulting (observation, control) pairs.
    """
    # TODO: add a Rust function to generate training pairs from .qbot files
    raise NotImplementedError(
        "Training data generation from .qbot files requires Rust bindings. "
        "Use the Rust training code for imitation learning, or implement RL below."
    )
def train_rl(map_file: str, epochs: int, lr: float, device: str):
    """
    Reinforcement learning training loop.

    This is the main training path going forward. For now it only drives the
    environment with random actions (plus a model forward pass) as a smoke
    test; `epochs` and `lr` are reserved for the real RL algorithm.
    """
    print(f"Setting up environment with map: {map_file}")
    environment = StrafeEnvironment(map_file)
    print(f"Using device: {device}")
    dev = torch.device(device)
    model = StrafeNet().to(dev)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {num_params:,} parameters")
    log_writer = SummaryWriter("runs/strafe-ai")
    # TODO: implement RL algorithm (PPO via Stable Baselines3, or custom)
    # Until then: random actions to verify the environment works end to end.
    print("Running environment test (random actions)...")
    obs, info = environment.reset()
    total_reward = 0.0
    for step in range(100):
        sampled = environment.action_space.sample()
        # forward pass through the model (output unused — just verifies shapes)
        hist_batch = torch.tensor(obs["position_history"], device=dev).unsqueeze(0)
        depth_batch = torch.tensor(obs["depth"], device=dev).unsqueeze(0)
        with torch.no_grad():
            model(hist_batch, depth_batch)
        obs, reward, terminated, truncated, info = environment.step(sampled)
        total_reward += reward
        if terminated or truncated:
            obs, info = environment.reset()
        if step % 10 == 0:
            pos = environment.get_position()
            print(f" step {step:4d} | pos = ({pos[0]:.1f}, {pos[1]:.1f}, {pos[2]:.1f})")
    print(f"Test complete. Total reward: {total_reward}")
    log_writer.close()
def main():
    """CLI entry point: parse command-line arguments and start training."""
    cli = argparse.ArgumentParser(description="Train strafe-ai")
    cli.add_argument("--map-file", required=True, help="Path to .snfm map file")
    cli.add_argument("--bot-file", help="Path to .qbot file (for imitation learning)")
    cli.add_argument("--epochs", type=int, default=10000)
    cli.add_argument("--lr", type=float, default=0.001)
    # default to CUDA when available, else CPU
    cli.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    opts = cli.parse_args()
    train_rl(opts.map_file, opts.epochs, opts.lr, opts.device)


if __name__ == "__main__":
    main()