48 Commits
burn ... main

Author SHA1 Message Date
37f0987e62 simplify position history padding 2026-03-31 17:48:19 -07:00
7046da289b faster mouse movement 2026-03-31 13:57:42 -07:00
05f6432947 remove goofy clamping 2026-03-31 13:56:52 -07:00
Cameron Grant 4573aaa5ae Add dynamic speed bonus to reward function and adjust PPO training parameters. 2026-03-31 13:51:17 -07:00
Cameron Grant 76ff41203d Add support for resuming training from a saved checkpoint. Edited values. 2026-03-31 13:45:30 -07:00
87c3967caa Revert "work around strafe tick bug" (This reverts commit e23455bba2.) 2026-03-31 13:00:26 -07:00
4916b76126 update physics 2026-03-31 13:00:26 -07:00
Cameron Grant b30b43ed9d Added skip connections 2026-03-31 12:57:34 -07:00
Cameron Grant b6b208975c Refactor action space to continuous movement and simplify mouse controls. 2026-03-31 12:51:44 -07:00
Cameron Grant 2795090f13 Fixed issues with watch helper 2026-03-31 12:51:25 -07:00
Cameron Grant 5564c6529a Added auto contrast to previews. 2026-03-31 12:43:06 -07:00
Cameron Grant eb62ac2099 Continuous actions 2026-03-31 11:46:27 -07:00
Cameron Grant 5e385f806c Added CNN 2026-03-31 11:40:13 -07:00
a030afb018 remove extra rewards 2026-03-31 11:31:48 -07:00
e23455bba2 work around strafe tick bug 2026-03-31 11:20:09 -07:00
f7521d49f2 idle so get_observation shows next state rather than this state 2026-03-31 10:43:47 -07:00
ee02c9cbda if statement has bad cost benefit 2026-03-31 10:05:10 -07:00
2c59742799 simplify get_map_positions 2026-03-31 09:58:48 -07:00
24bcb63d0e add velocity to position history, fix position history, use zero camera pitch 2026-03-31 09:52:46 -07:00
Cameron Grant ccb4fb5791 Draw topdown map to tensorboard during training. 2026-03-30 20:56:20 -07:00
Cameron Grant 0444c6d68a Add watch.py for visualizing agent performance; update Python dependencies in poetry.lock. 2026-03-30 20:39:50 -07:00
bd2f60fb72 fix reset 2026-03-30 16:23:15 -07:00
4dd3201192 reward horizontal speed 2026-03-30 16:20:14 -07:00
44c8c53122 deny reward when pressing A and D 2026-03-30 16:16:08 -07:00
aaa5a158e8 allow playback to run until WR time elapses 2026-03-30 16:10:45 -07:00
Cameron Grant cb59737985 Optimize PPO training parameters and switch to SubprocVecEnv for multi-environment setups. 2026-03-30 15:56:45 -07:00
Cameron Grant 5ba65ba4d0 Implement reward shaping for bhop mechanics and adjust PPO training parameters 2026-03-30 15:43:12 -07:00
eeb935dcd6 note stable-baselines3 in readme 2026-03-30 15:11:23 -07:00
Cameron Grant 89a398f1f6 Replace custom RL training loop with Stable-Baselines3 PPO implementation. Update training script and dependencies accordingly. 2026-03-30 15:09:17 -07:00
Cameron Grant 9b1f61b128 Refactor training pipeline to implement RL loop with policy gradient (REINFORCE). 2026-03-30 14:51:38 -07:00
Cameron Grant ae02fbba79 Update simulation completion logic to include run_finished status. 2026-03-30 14:43:13 -07:00
Cameron Grant 679962024f Reset best_progress and run_finished in physics state initialization. 2026-03-30 14:41:48 -07:00
c11910b33e handsomely reward run completion 2026-03-30 14:29:09 -07:00
e8845ce28d incremental reward 2026-03-30 14:15:56 -07:00
73dcb93d5e more steps 2026-03-30 14:08:31 -07:00
3b00541644 add reward 2026-03-30 14:08:31 -07:00
b29d5f3845 normal torch ok 2026-03-30 13:15:25 -07:00
48cd49bd43 document setup using uv + fish 2026-03-30 13:13:19 -07:00
8ac74d36f0 CompleteMap is only needed for construction 2026-03-30 12:46:28 -07:00
9995e852d4 update ndarray 2026-03-30 12:40:16 -07:00
Cameron Grant 0a1c8068fe Add rustfmt.toml and improve code formatting in lib.rs. 2026-03-30 12:39:55 -07:00
792078121b CompleteMap is only needed for construction 2026-03-30 12:26:34 -07:00
df3e813dd9 fix float to bool 2026-03-30 12:25:37 -07:00
cea0bcbaf3 Use slice deconstruction 2026-03-30 12:18:42 -07:00
e79d0378ac Use helper function 2026-03-30 12:16:57 -07:00
Cameron Grant d907672daa Add poetry.lock file for Python dependency management. 2026-03-30 12:10:13 -07:00
Cameron Grant 899278ff64 Add Cargo.lock file for Rust dependency management. 2026-03-30 11:53:59 -07:00
Cameron Grant 5da88a0f69 Converted full project to PyTorch. 2026-03-30 11:39:04 -07:00
18 changed files with 3677 additions and 7262 deletions

.gitignore (vendored, 29 lines changed)

@@ -1,2 +1,29 @@
/files
# Rust
/target
# Python
__pycache__
*.pyc
*.egg-info
.eggs
dist
build
.venv
# Data files
/files
*.snfm
*.qbot
*.snfb
*.model
*.bin
# TensorBoard
runs/
# IDE / tools
.claude
.idea
_rust.*
ppo_strafe_*_steps.zip

Cargo.lock generated (6596 lines changed)
File diff suppressed because it is too large.

Cargo.toml

@@ -1,11 +1,15 @@
 [package]
-name = "strafe-ai"
+name = "strafe-ai-py"
 version = "0.1.0"
 edition = "2024"
+[lib]
+name = "_rust"
+crate-type = ["cdylib"]
 [dependencies]
-burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
-clap = { version = "4.6.0", features = ["derive"] }
+pyo3 = { version = "0.28", features = ["extension-module"] }
+numpy = "0.28"
 glam = "0.32.1"
 pollster = "0.4.0"
 wgpu = "29.0.0"

README.md (new file, 77 lines)

@@ -0,0 +1,77 @@
# strafe-ai-py
PyTorch + Rust environment for training an AI to play bhop maps.
## Architecture
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
The Rust side exposes a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
## Setup
Requires:
- Python 3.10+
- Rust toolchain
- CUDA toolkit
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
```bash
# create a virtual environment
uv venv
source .venv/bin/activate.fish # or .venv\Scripts\activate on Windows
# install Python dependencies (torch first, then the rest)
uv pip install torch
uv pip install maturin gymnasium numpy tensorboard stable-baselines3
# build and install the Rust extension + Python deps
maturin develop --release
# or install with RL extras
pip install -e ".[rl]"
```
## Usage
```bash
# test the environment
python -m strafe_ai.train --map-file path/to/map.snfm
# tensorboard
tensorboard --logdir runs
```
## Project Structure
```
strafe_ai/ Python package
environment.py Gymnasium env wrapping Rust sim
model.py PyTorch model (StrafeNet)
train.py Training script
src/
lib.rs PyO3 bindings (physics + rendering)
```
## Environment API
```python
from strafe_ai import StrafeEnvironment
env = StrafeEnvironment("map.snfm")
obs, info = env.reset()
# obs["position_history"] — (80,) float32 — 10 recent states (position, velocity, angles)
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
action = env.action_space.sample() # 7 floats
obs, reward, terminated, truncated, info = env.step(action)
```
## Next Steps
- [ ] Implement reward function (curve_dt along WR path)
- [ ] Train with PPO (Stable Baselines3)
- [ ] Add .qbot replay loading for imitation learning pretraining
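For quick debugging, the flat depth observation can be reshaped back into its 64x36 image. A minimal sketch, using the shapes from the Environment API above; the zero-filled array here is a stand-in for a real `env.reset()` result:

```python
import numpy as np

# Dimensions from the environment API: a 64x36 depth buffer, flattened row by row.
WIDTH, HEIGHT = 64, 36

def depth_to_image(depth: np.ndarray) -> np.ndarray:
    """Reshape the flat (2304,) depth observation into a (36, 64) image."""
    assert depth.shape == (WIDTH * HEIGHT,)
    return depth.reshape(HEIGHT, WIDTH)

# Stand-in for obs["depth"] from env.reset()
obs_depth = np.zeros(WIDTH * HEIGHT, dtype=np.float32)
img = depth_to_image(obs_depth)
print(img.shape)  # (36, 64)
```

The reshaped array can be written straight to TensorBoard or saved as an image for inspection.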

poetry.lock generated (new file, 2324 lines)
File diff suppressed because it is too large.

pyproject.toml (new file, 21 lines)

@@ -0,0 +1,21 @@
[build-system]
requires = ["maturin>=1.5,<2.0"]
build-backend = "maturin"
[project]
name = "strafe-ai"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"gymnasium",
"numpy",
"tensorboard",
"stable-baselines3",
"pillow",
]
[tool.maturin]
python-source = "."
features = ["pyo3/extension-module"]
module-name = "strafe_ai._rust"

rustfmt.toml

@@ -1 +1 @@
hard_tabs = true
hard_tabs = true

src/inference.rs (deleted file)

@@ -1,268 +0,0 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Simulate(SimulateSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Simulate(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct SimulateSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
model_file: std::path::PathBuf,
#[arg(long)]
output_file: Option<std::path::PathBuf>,
#[arg(long)]
map_file: std::path::PathBuf,
}
impl SimulateSubcommand {
fn run(self) {
let output_file = self.output_file.unwrap_or_else(|| {
let mut file_name = self
.model_file
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_owned();
file_name.push_str("_replay.snfb");
let mut path = self.model_file.clone();
path.set_file_name(file_name);
path
});
inference(
self.gpu_id.unwrap_or_default(),
self.model_file,
output_file,
self.map_file,
);
}
}
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{InferenceBackend, Net, POSITION_HISTORY, POSITION_HISTORY_SIZE, SIZE};
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
pub struct Recording {
instructions: Vec<TimedInstruction<PhysicsInputInstruction, PhysicsTime>>,
}
struct FrameState {
trajectory: strafesnet_physics::physics::Trajectory,
camera: strafesnet_physics::physics::PhysicsCamera,
}
impl FrameState {
fn pos(&self, time: PhysicsTime) -> glam::Vec3 {
self.trajectory
.extrapolated_position(time)
.map(Into::<f32>::into)
.to_array()
.into()
}
fn vel(&self, time: PhysicsTime) -> glam::Vec3 {
self.trajectory
.extrapolated_velocity(time)
.map(Into::<f32>::into)
.to_array()
.into()
}
fn angles(&self) -> glam::Vec2 {
self.camera.simulate_move_angles(glam::IVec2::ZERO)
}
}
struct Session {
geometry_shared: PhysicsData,
simulation: PhysicsState,
recording: Recording,
}
impl Session {
fn get_frame_state(&self) -> FrameState {
FrameState {
trajectory: self.simulation.camera_trajectory(&self.geometry_shared),
camera: self.simulation.camera(),
}
}
fn run(&mut self, time: PhysicsTime, instruction: PhysicsInputInstruction) {
let instruction = TimedInstruction { time, instruction };
self.recording.instructions.push(instruction.clone());
PhysicsContext::run_input_instruction(
&mut self.simulation,
&self.geometry_shared,
instruction,
);
}
}
fn inference(
gpu_id: usize,
model_file: std::path::PathBuf,
output_file: std::path::PathBuf,
map_file: std::path::PathBuf,
) {
// pick device
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
// load model
let mut model: Net<InferenceBackend> = Net::init(&device);
model = model
.load_file(
model_file,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device,
)
.unwrap();
// load map
let map_file = std::fs::read(map_file).unwrap();
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let modes = map.modes.clone().denormalize();
let mode = modes
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
.unwrap();
let start_zone = map.models.get(mode.get_start().get() as usize).unwrap();
let start_offset = glam::Vec3::from_array(
start_zone
.transform
.translation
.map(|f| f.into())
.to_array(),
);
// setup graphics
let mut g = InputGenerator::new(&map);
// setup simulation
let mut session = Session {
geometry_shared: PhysicsData::new(&map),
simulation: PhysicsState::default(),
recording: Recording {
instructions: Vec::new(),
},
};
let mut time = PhysicsTime::ZERO;
// reset to start zone
session.run(time, PhysicsInputInstruction::Mode(ModeInstruction::Reset));
// session.run(
// time,
// PhysicsInputInstruction::Misc(MiscInstruction::SetSensitivity(?)),
// );
session.run(
time,
PhysicsInputInstruction::Mode(ModeInstruction::Restart(
strafesnet_common::gameplay_modes::ModeId::MAIN,
)),
);
// TEMP: turn mouse left
let mut mouse_pos = glam::ivec2(-5300, 0);
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
let mut input_floats = Vec::new();
let mut depth_floats = Vec::new();
// setup agent-simulation feedback loop
for _ in 0..20 * 100 {
// generate inputs
let frame_state = session.get_frame_state();
g.generate_inputs(
frame_state.pos(time) - start_offset,
frame_state.vel(time),
frame_state.angles(),
&mut input_floats,
&mut depth_floats,
);
// inference
let inputs = Tensor::from_data(
TensorData::new(
input_floats.clone(),
Shape::new([1, POSITION_HISTORY_SIZE * POSITION_HISTORY]),
),
&device,
);
let depth = Tensor::from_data(
TensorData::new(
depth_floats.clone(),
Shape::new([1, 1, SIZE.y as usize, SIZE.x as usize]),
),
&device,
);
let outputs = model
.forward(inputs, depth)
.into_data()
.into_vec::<f32>()
.unwrap();
let &[
move_forward,
move_left,
move_back,
move_right,
jump,
mouse_dx,
mouse_dy,
] = outputs.as_slice()
else {
panic!()
};
macro_rules! set_control {
($control:ident,$output:expr) => {
session.run(
time,
PhysicsInputInstruction::SetControl(SetControlInstruction::$control(
0.5 < $output,
)),
);
};
}
set_control!(SetMoveForward, move_forward);
set_control!(SetMoveLeft, move_left);
set_control!(SetMoveBack, move_back);
set_control!(SetMoveRight, move_right);
set_control!(SetJump, jump);
mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
let next_time = time + STEP;
session.run(
time,
PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(MouseState {
pos: mouse_pos,
time: next_time,
})),
);
time = next_time;
// clear
depth_floats.clear();
input_floats.clear();
}
let file = std::fs::File::create(output_file).unwrap();
strafesnet_snf::bot::write_bot(
std::io::BufWriter::new(file),
strafesnet_physics::VERSION.get(),
core::mem::take(&mut session.recording.instructions),
)
.unwrap();
}
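The output-decoding step in the loop above (five buttons thresholded at 0.5, mouse deltas rounded to integers) can be sketched in Python; the function name and dict keys are illustrative, not part of the crate:

```python
def decode_action(outputs: list[float]) -> tuple[dict[str, bool], tuple[int, int]]:
    """Split the 7 model outputs into button states and an integer mouse delta.

    Mirrors the Rust loop: booleans via `0.5 < output`, mouse via round().
    """
    forward, left, back, right, jump, mouse_dx, mouse_dy = outputs
    buttons = {
        "move_forward": 0.5 < forward,
        "move_left": 0.5 < left,
        "move_back": 0.5 < back,
        "move_right": 0.5 < right,
        "jump": 0.5 < jump,
    }
    mouse = (round(mouse_dx), round(mouse_dy))
    return buttons, mouse

buttons, mouse = decode_action([0.9, 0.1, 0.0, 0.8, 1.0, -3.4, 0.6])
print(buttons["move_forward"], mouse)  # True (-3, 1)
```

Note that Python's `round()` uses banker's rounding at exact halves, while Rust's `f32::round` rounds away from zero; for model outputs this edge case is unlikely to matter.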

src/inputs.rs (deleted file)

@@ -1,178 +0,0 @@
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
use strafesnet_graphics::setup;
use crate::net::{HistoricState, POSITION_HISTORY, POSITION_HISTORY_SIZE, SIZE};
// bytes_per_row needs to be a multiple of 256.
const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256);
pub struct InputGenerator {
device: wgpu::Device,
queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
position_history: Vec<HistoricState>,
}
impl InputGenerator {
pub fn new(map: &strafesnet_common::map::CompleteMap) -> Self {
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device, &queue, SIZE, FORMAT, LIMITS,
);
graphics.change_map(&device, &queue, map).unwrap();
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let position_history = Vec::with_capacity(POSITION_HISTORY);
Self {
device,
queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
position_history,
}
}
pub fn generate_inputs(
&mut self,
pos: glam::Vec3,
vel: glam::Vec3,
angles: glam::Vec2,
inputs: &mut Vec<f32>,
depth: &mut Vec<f32>,
) {
// write position history to model inputs
let camera =
strafesnet_graphics::graphics::view_inv(pos, glam::vec2(angles.x, 0.0)).inverse();
for state in self.position_history.iter().rev() {
let relative_pos = camera.project_point3(state.pos);
let relative_vel = camera.transform_vector3(state.vel);
let relative_ang = glam::vec2(angles.x - state.angles.x, state.angles.y);
inputs.extend_from_slice(&relative_pos.to_array());
inputs.extend_from_slice(&relative_vel.to_array());
inputs.extend_from_slice(&relative_ang.to_array());
}
// fill remaining history with zeroes
inputs.extend(core::iter::repeat_n(
0.0,
POSITION_HISTORY_SIZE * (POSITION_HISTORY - self.position_history.len()),
));
// track position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history
.push(HistoricState { pos, vel, angles });
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = HistoricState { pos, vel, angles };
}
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE.y),
},
},
wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
);
self.queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
		// strip the stride padding from each row
for y in 0..SIZE.y {
depth.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 2.0 * (1.0 - f32::from_le_bytes(b.try_into().unwrap()))),
)
}
self.texture_data.clear();
}
}
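The stride handling above can be sketched with NumPy: each row is padded out to a 256-byte multiple for the wgpu copy, then the padding is stripped and each f32 is remapped with `2*(1-d)`. The sizes are the constants from this file; the NumPy helper itself is an illustrative assumption:

```python
import numpy as np

SIZE_X, SIZE_Y = 64, 36
# bytes_per_row must be a multiple of 256 for wgpu texture-to-buffer copies
STRIDE = -(-SIZE_X * 4 // 256) * 256  # next multiple of 256 above 64*4 -> 256

def unpack_depth(raw: bytes) -> np.ndarray:
    """Strip per-row stride padding and remap raw f32 depth to 2*(1-d)."""
    rows = np.frombuffer(raw, dtype=np.uint8).reshape(SIZE_Y, STRIDE)
    tight = rows[:, : SIZE_X * 4].copy()  # drop the alignment padding bytes
    d = tight.view(np.float32).reshape(SIZE_Y, SIZE_X)
    return 2.0 * (1.0 - d)

# A buffer of all-1.0 depth values (far plane) maps to 0.0 everywhere.
raw = np.ones((SIZE_Y, STRIDE // 4), dtype=np.float32).tobytes()
depth = unpack_depth(raw)
print(depth.shape, float(depth[0, 0]))  # (36, 64) 0.0
```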

src/lib.rs (new file, 503 lines)

@@ -0,0 +1,503 @@
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
use strafesnet_roblox_bot_player::head::Time as PlaybackTime;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
const SIZE_X: u32 = 64;
const SIZE_Y: u32 = 36;
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
const POSITION_HISTORY: usize = 10;
const POSITION_HISTORY_SIZE: usize = size_of::<HistoricState>() / size_of::<f32>();
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
fn float_to_bool(float: f32) -> bool {
0.5 < float
}
struct HistoricState {
pos: glam::Vec3,
vel: glam::Vec3,
angles: glam::Vec2,
}
/// The Rust-side environment that wraps physics simulation and depth rendering.
/// Exposed to Python via PyO3.
#[pyclass]
struct StrafeEnv {
// rendering
wgpu_device: wgpu::Device,
wgpu_queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
// simulation
geometry: PhysicsData,
simulation: PhysicsState,
time: PhysicsTime,
mouse_pos: glam::IVec2,
start_offset: glam::Vec3,
// scoring
bot: strafesnet_roblox_bot_player::bot::CompleteBot,
bvh: strafesnet_roblox_bot_player::bvh::Bvh,
wr_time: PhysicsTime,
best_progress: PlaybackTime,
run_finished: bool,
// position history
position_history: Vec<HistoricState>,
// map data (for platform visualization)
map: strafesnet_common::map::CompleteMap,
}
#[pymethods]
impl StrafeEnv {
/// Create a new environment from a map file path.
#[new]
fn new(map_path: &str, bot_path: &str) -> PyResult<Self> {
let map_file = std::fs::read(map_path)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
.into_complete_map()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
let bot_file = std::fs::read(bot_path)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read bot: {e}")))?;
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file))
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode bot: {e:?}")))?;
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to init bot: {e:?}")))?;
let bvh = strafesnet_roblox_bot_player::bvh::Bvh::new(&bot);
let wr_time = bot
.run_duration(strafesnet_roblox_bot_file::v0::ModeID(0))
.unwrap();
// compute start offset
let modes = map.modes.clone().denormalize();
let mode = modes
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
.ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
let start_zone = map
.models
.get(mode.get_start().get() as usize)
.ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
let start_offset = glam::Vec3::from_array(
start_zone
.transform
.translation
.map(|f| f.into())
.to_array(),
);
// init wgpu
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (wgpu_device, wgpu_queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&wgpu_device,
&wgpu_queue,
SIZE,
FORMAT,
LIMITS,
);
graphics
.change_map(&wgpu_device, &wgpu_queue, &map)
.unwrap();
let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE_X,
height: SIZE_Y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let geometry = PhysicsData::new(&map);
let simulation = PhysicsState::default();
let mut env = Self {
wgpu_device,
wgpu_queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
geometry,
simulation,
time: PhysicsTime::ZERO,
mouse_pos: glam::ivec2(-5300, 0),
start_offset,
position_history: Vec::with_capacity(POSITION_HISTORY),
bot,
bvh,
best_progress: PlaybackTime::ZERO,
wr_time: wr_time.coerce(),
run_finished: false,
map,
};
// initial reset
env.do_reset();
Ok(env)
}
/// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
self.do_reset();
self.get_observation()
}
/// Take one step with the given action.
/// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
/// Returns (position_history, depth, done)
fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
let &[
move_forward,
move_left,
move_back,
move_right,
jump,
mouse_dx,
mouse_dy,
] = action.as_slice()
else {
return Err(PyRuntimeError::new_err("Action must have 7 elements"));
};
// apply controls
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveForward(float_to_bool(move_forward)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveLeft(float_to_bool(move_left)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveBack(float_to_bool(move_back)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetMoveRight(float_to_bool(move_right)),
));
self.run_instruction(PhysicsInputInstruction::SetControl(
SetControlInstruction::SetJump(float_to_bool(jump)),
));
// apply mouse
self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
let next_time = self.time + STEP;
self.run_instruction(PhysicsInputInstruction::Mouse(
MouseInstruction::SetNextMouse(MouseState {
pos: self.mouse_pos,
time: next_time,
}),
));
self.time = next_time;
// immediately idle to advance time before get_observation
self.run_instruction(PhysicsInputInstruction::Idle);
let (pos_hist, depth) = self.get_observation();
// done on completion or timeout (world record not beaten)
let done = self.run_finished || self.wr_time < self.time;
Ok((pos_hist, depth, done))
}
fn reward(&mut self) -> f32 {
// this is going int -> float -> int but whatever.
let point = self
.simulation
.body()
.position
.map(Into::<f32>::into)
.to_array()
.into();
let time = self.bvh.closest_time_to_point(&self.bot, point).unwrap();
let progress = self.bot.playback_time(time);
let mut reward = 0.0;
// reward incremental progress
if self.best_progress < progress {
let diff = progress - self.best_progress;
self.best_progress = progress;
let diff_f32: f32 = diff.into();
reward += diff_f32;
};
// reward completion
if !self.run_finished {
if let Some(finish_time) = self.simulation.get_finish_time() {
let finish_time: f32 = finish_time.into();
let wr_time: f32 = self.wr_time.into();
// more reward for completing longer maps
reward += wr_time * wr_time / finish_time;
self.run_finished = true;
}
}
reward
}
/// Get physics state: [vel_x, vel_y, vel_z]
fn get_velocity(&self) -> Vec<f32> {
self.simulation
.body()
.velocity
.map(Into::<f32>::into)
.to_array()
.to_vec()
}
/// Get the current position as [x, y, z].
fn get_position(&self) -> Vec<f32> {
let trajectory = self.simulation.camera_trajectory(&self.geometry);
let pos = trajectory
.extrapolated_position(self.time)
.map(Into::<f32>::into)
.to_array();
pos.to_vec()
}
/// Get the WR replay path as list of [x, y, z] positions.
fn get_wr_path(&self) -> Vec<Vec<f32>> {
self.bot
.timelines()
.output_events
.iter()
.map(|event| {
vec![
event.event.position.x,
event.event.position.y,
event.event.position.z,
]
})
.collect()
}
/// Get map model positions as list of [x, y, z] (translation only).
fn get_map_positions(&self) -> Vec<Vec<f32>> {
self.map
.models
.iter()
.map(|model| {
model
.transform
.translation
.map(|f| f.into())
.to_array()
.to_vec()
})
.collect()
}
/// Get observation dimensions.
#[staticmethod]
fn observation_sizes() -> (usize, usize) {
(POSITION_HISTORY * POSITION_HISTORY_SIZE, DEPTH_PIXELS)
}
/// Get action size.
#[staticmethod]
fn action_size() -> usize {
7
}
}
impl StrafeEnv {
fn do_reset(&mut self) {
self.simulation = PhysicsState::default();
self.time = PhysicsTime::ZERO;
self.mouse_pos = glam::ivec2(-5300, 0);
self.position_history.clear();
self.best_progress = PlaybackTime::ZERO;
self.run_finished = false;
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
strafesnet_common::gameplay_modes::ModeId::MAIN,
)));
}
fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
PhysicsContext::run_input_instruction(
&mut self.simulation,
&self.geometry,
TimedInstruction {
time: self.time,
instruction,
},
);
}
fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
let trajectory = self.simulation.camera_trajectory(&self.geometry);
let pos = glam::Vec3::from_array(
trajectory
.extrapolated_position(self.time)
.map(Into::<f32>::into)
.to_array(),
) - self.start_offset;
let vel: glam::Vec3 = trajectory
.extrapolated_velocity(self.time)
.map(Into::<f32>::into)
.to_array()
.into();
let camera = self.simulation.camera();
let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
// build position history input
let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * POSITION_HISTORY_SIZE);
let cam_matrix =
strafesnet_graphics::graphics::view_inv(pos, glam::vec2(angles.x, 0.0)).inverse();
for state in self.position_history.iter().rev() {
let relative_pos = cam_matrix.transform_point3(state.pos);
let relative_vel = cam_matrix.transform_vector3(state.vel);
let relative_ang = glam::vec2(angles.x - state.angles.x, state.angles.y);
pos_hist.extend_from_slice(&relative_pos.to_array());
pos_hist.extend_from_slice(&relative_vel.to_array());
pos_hist.extend_from_slice(&relative_ang.to_array());
}
// pad remaining history with zeros
pos_hist.extend(core::iter::repeat_n(
0.0,
POSITION_HISTORY_SIZE * (POSITION_HISTORY - self.position_history.len()),
));
// update position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history
.push(HistoricState { pos, vel, angles });
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = HistoricState { pos, vel, angles };
}
// render depth
let depth = self.render_depth(pos, angles);
(pos_hist, depth)
}
fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
let mut encoder =
self.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("depth encoder"),
});
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE_Y),
},
},
wgpu::Extent3d {
width: SIZE_X,
height: SIZE_Y,
depth_or_array_layers: 1,
},
);
self.wgpu_queue.submit([encoder.finish()]);
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.wgpu_device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
let mut depth = Vec::with_capacity(DEPTH_PIXELS);
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
for y in 0..SIZE_Y {
depth.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
);
}
self.texture_data.clear();
depth
}
}
/// Python module definition
#[pymodule]
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<StrafeEnv>()?;
Ok(())
}

src/main.rs (deleted file)

@@ -1,32 +0,0 @@
#![recursion_limit = "256"]
use clap::{Parser, Subcommand};
mod inference;
mod inputs;
mod net;
mod training;
#[derive(Parser)]
#[command(author,version,about,long_about=None)]
#[command(propagate_version = true)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
#[command(flatten)]
Roblox(inference::Commands),
#[command(flatten)]
Source(training::Commands),
}
fn main() {
let cli = Cli::parse();
match cli.command {
Commands::Roblox(commands) => commands.run(),
Commands::Source(commands) => commands.run(),
}
}

View File

@@ -1,95 +0,0 @@
use burn::backend::Autodiff;
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu};
use burn::prelude::*;
pub type InferenceBackend = burn::backend::Cuda<f32>;
pub type TrainingBackend = Autodiff<InferenceBackend>;
pub const SIZE: glam::UVec2 = glam::uvec2(64, 36);
pub const DEPTH_SIZE: usize = (SIZE.x * SIZE.y) as usize;
pub const POSITION_HISTORY: usize = 10;
pub const POSITION_HISTORY_SIZE: usize = size_of::<HistoricState>() / size_of::<f32>();
const CONV1_SIZE: usize = 8;
const CONV2_SIZE: usize = 16;
const INPUT: usize = ((SIZE.x >> 2) * (SIZE.y >> 2)) as usize * CONV2_SIZE
+ POSITION_HISTORY_SIZE * POSITION_HISTORY;
pub const HIDDEN: [usize; 3] = [INPUT >> 3, INPUT >> 5, INPUT >> 7];
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
pub const OUTPUT: usize = 7;
pub struct HistoricState {
pub pos: glam::Vec3,
pub vel: glam::Vec3,
pub angles: glam::Vec2,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
input: Linear<B>,
conv1: Conv2d<B>,
conv2: Conv2d<B>,
pool: MaxPool2d,
dropout: Dropout,
hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>,
activation: Relu,
}
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
let conv1 = Conv2dConfig::new([1, CONV1_SIZE], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device);
let conv2 = Conv2dConfig::new([CONV1_SIZE, CONV2_SIZE], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device);
let pool = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
let dropout = DropoutConfig::new(0.1).init();
Self {
input,
conv1,
conv2,
pool,
dropout,
hidden,
output,
activation: Relu::new(),
}
}
pub fn forward(&self, input: Tensor<B, 2>, depth: Tensor<B, 4>) -> Tensor<B, 2> {
let x = self.conv1.forward(depth);
let x = self.activation.forward(x);
let x = self.pool.forward(x);
let x = self.dropout.forward(x);
let x = self.conv2.forward(x);
let x = self.activation.forward(x);
let x = self.pool.forward(x);
let x = self.dropout.forward(x);
let x = x.flatten(1, 3);
let x = Tensor::cat(vec![input, x], 1);
let x = self.input.forward(x);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}

View File

@@ -1,236 +0,0 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Train(TrainSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Train(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct TrainSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
epochs: Option<usize>,
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long)]
map_file: std::path::PathBuf,
#[arg(long)]
bot_file: std::path::PathBuf,
}
impl TrainSubcommand {
fn run(self) {
training(
self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001),
self.map_file,
self.bot_file,
);
}
}
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{
DEPTH_SIZE, Net, OUTPUT, POSITION_HISTORY, POSITION_HISTORY_SIZE, SIZE, TrainingBackend,
};
use strafesnet_roblox_bot_file::v0;
fn training(
gpu_id: usize,
epochs: usize,
learning_rate: f64,
map_file: std::path::PathBuf,
bot_file: std::path::PathBuf,
) {
// read files
let map_file = std::fs::read(map_file).unwrap();
let bot_file = std::fs::read(bot_file).unwrap();
// load map
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
// load replay
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// set up graphics
let mut g = InputGenerator::new(&map);
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = POSITION_HISTORY_SIZE * size_of::<f32>();
let depth_size = DEPTH_SIZE * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut depth = Vec::with_capacity(depth_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
fn vec3(v: v0::Vector3) -> glam::Vec3 {
glam::vec3(v.x, v.y, v.z)
}
fn angles(a: v0::Vector3) -> glam::Vec2 {
glam::vec2(a.y, a.x)
}
let pos = vec3(output_event.event.position) - world_offset;
let vel = vec3(output_event.event.velocity);
let angles = angles(output_event.event.angles);
g.generate_inputs(pos, vel, angles, &mut inputs, &mut depth);
}
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
let num_params = model.num_params();
println!("Training model ({} parameters)", num_params);
let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_data(
TensorData::new(
inputs,
Shape::new([training_samples, POSITION_HISTORY_SIZE * POSITION_HISTORY]),
),
&device,
);
let depth = Tensor::from_data(
TensorData::new(
depth,
Shape::new([training_samples, 1, SIZE.y as usize, SIZE.x as usize]),
),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..epochs {
let predictions = model.forward(inputs.clone(), depth.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
let loss_scalar = loss.clone().into_scalar();
if epoch == 0 {
// kinda a fake print, but that's what is happening after this point
println!("Compiling optimized GPU kernels...");
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
// get the best model
if loss_scalar < best_loss {
best_loss = loss_scalar;
best_model = model.clone();
}
model = optim.step(learning_rate, model, grads);
if epoch % (epochs >> 4).max(1) == 0 || epoch == epochs - 1 {
println!(" epoch {epoch:>5} | loss = {loss_scalar:.8} | best_loss = {best_loss:.8}");
}
}
let file_name = format!("{}_{}.model", num_params, best_loss);
best_model
.save_file(
file_name,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
)
.unwrap();
}
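The nearest-output-event lookup above (exact hit, or the closer of the insertion point and its predecessor) is the standard closest-match pattern over a sorted timeline. A hedged Python equivalent using `bisect`; `times` stands in for the sorted `output_events` timestamps and the function name is illustrative:

```python
import bisect

def nearest_event(times: list[float], t: float) -> int:
    """Index of the timestamp closest to t, mirroring the Rust binary_search:
    consider the insertion point and its predecessor, keep the nearer one."""
    i = bisect.bisect_left(times, t)
    candidates = [j for j in (i - 1, i) if 0 <= j < len(times)]
    return min(candidates, key=lambda j: abs(times[j] - t))
```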

4
strafe_ai/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeFeaturesExtractor
__all__ = ["StrafeEnvironment", "StrafeFeaturesExtractor"]

99
strafe_ai/environment.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
Usage:
env = StrafeEnvironment("map.snfm", "replay.qbot")
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
"""
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from strafe_ai._rust import StrafeEnv
class StrafeEnvironment(gym.Env):
"""
Bhop environment with continuous action space.
Observations:
- depth: (1, 36, 64) — depth buffer as single-channel image for CNN
- position_history: (80,) — 10 entries x 8 floats (pos + vel + angles)
Actions: 4 continuous floats
[forward_back, left_right, jump, mouse_yaw]
The movement and jump axes are thresholded at 0 to produce binary controls.
mouse_yaw is scaled by 300 in step(); pitch is held at 0.
"""
metadata = {"render_modes": ["none"]}
def __init__(self, map_path: str, bot_path: str):
super().__init__()
self._env = StrafeEnv(map_path, bot_path)
pos_size, depth_size = StrafeEnv.observation_sizes()
self.observation_space = spaces.Dict({
"depth": spaces.Box(-1.0, 1.0, shape=(1, 36, 64), dtype=np.float32),
"position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
})
# 2 movement axes + jump + mouse yaw = 4 floats
self.action_space = spaces.Box(
low=np.array([-1, -1, -1, -1], dtype=np.float32),
high=np.array([1, 1, 1, 1], dtype=np.float32),
)
def _make_obs(self, pos_hist, depth):
return {
"depth": np.array(depth, dtype=np.float32).reshape(1, 36, 64),
"position_history": np.array(pos_hist, dtype=np.float32),
}
def reset(self, seed=None, options=None):
super().reset(seed=seed)
pos_hist, depth = self._env.reset()
return self._make_obs(pos_hist, depth), {}
def step(self, action):
# forward/back axis: >0 = forward, <0 = back
fb = float(action[0])
fwd = 1.0 if fb > 0 else 0.0
back = 1.0 if fb < 0 else 0.0
# left/right axis: >0 = right, <0 = left
lr = float(action[1])
right = 1.0 if lr > 0 else 0.0
left = 1.0 if lr < 0 else 0.0
# jump
jump = 1.0 if action[2] > 0 else 0.0
# mouse yaw only, no pitch
mouse_dx = float(action[3]) * 300.0
mouse_dy = 0.0
controls = [fwd, left, back, right, jump, mouse_dx, mouse_dy]
pos_hist, depth, done = self._env.step(controls)
obs = self._make_obs(pos_hist, depth)
# progress along WR path (primary signal)
reward = self._env.reward() * 100.0
# speed bonus that fades out over training
# _lifetime_steps counts total steps across all episodes for this env instance
self._lifetime_steps = getattr(self, '_lifetime_steps', 0) + 1
speed_weight = 1.0 - self._lifetime_steps / 500_000  # fades to 0 over 500k steps
if speed_weight > 0:
vel = self._env.get_velocity()
hspeed = (vel[0] ** 2 + vel[2] ** 2) ** 0.5
reward += hspeed * speed_weight
return obs, reward, done, False, {}
def get_position(self):
return np.array(self._env.get_position(), dtype=np.float32)
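`step` folds the 4-float action back into the 7 control floats the Rust `StrafeEnv` consumes. The mapping can be sketched as a standalone helper, using the same sign conventions and 300x yaw scaling as above; the function name is illustrative, not part of the package:

```python
def decode_action(action) -> list[float]:
    """Expand [forward_back, left_right, jump, mouse_yaw] into the 7
    control floats: [fwd, left, back, right, jump, mouse_dx, mouse_dy]."""
    fb, lr, jump, yaw = (float(a) for a in action)
    return [
        1.0 if fb > 0 else 0.0,    # MoveForward
        1.0 if lr < 0 else 0.0,    # MoveLeft
        1.0 if fb < 0 else 0.0,    # MoveBack
        1.0 if lr > 0 else 0.0,    # MoveRight
        1.0 if jump > 0 else 0.0,  # Jump
        yaw * 300.0,               # mouse_dx (yaw only)
        0.0,                       # mouse_dy (pitch locked)
    ]
```

Opposed directions can never be held simultaneously, which makes the earlier "deny reward when pressing A and D" penalty unnecessary under this action space.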

90
strafe_ai/model.py Normal file
View File

@@ -0,0 +1,90 @@
"""
Custom feature extractor for SB3 — ResNet-style CNN on depth buffer,
MLP on position history (which includes velocity).
"""
import torch
import torch.nn as nn
import gymnasium as gym
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class ResBlock(nn.Module):
"""Residual block with skip connection."""
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
x = torch.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
return torch.relu(x + residual)
class StrafeFeaturesExtractor(BaseFeaturesExtractor):
"""
Processes the multi-input observation:
- depth (1, 36, 64) -> ResNet CNN -> 256 features
- position_history (80,) -> MLP -> 64 features
Concatenated -> 320 features total
"""
def __init__(self, observation_space: gym.spaces.Dict):
features_dim = 320
super().__init__(observation_space, features_dim=features_dim)
depth_shape = observation_space["depth"].shape # (1, H, W)
h, w = depth_shape[1], depth_shape[2]
# ResNet-style CNN for depth buffer
self.depth_cnn = nn.Sequential(
# initial conv: 1 -> 32 channels, downsample
nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
# residual block at 32 channels
ResBlock(32),
# downsample: 32 -> 64 channels
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
# residual block at 64 channels
ResBlock(64),
# downsample: 64 -> 128 channels
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
# residual block at 128 channels
ResBlock(128),
# global average pooling -> 128 features regardless of input size
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(128, 256),
nn.ReLU(),
)
# MLP for position history (includes velocity + angles)
pos_size = observation_space["position_history"].shape[0]
self.position_mlp = nn.Sequential(
nn.Linear(pos_size, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
)
def forward(self, observations: dict) -> torch.Tensor:
depth_features = self.depth_cnn(observations["depth"])
pos_features = self.position_mlp(observations["position_history"])
return torch.cat([depth_features, pos_features], dim=1)

294
strafe_ai/train.py Normal file
View File

@@ -0,0 +1,294 @@
"""
RL training script for strafe-ai using PPO.
Usage:
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot
"""
import argparse
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeFeaturesExtractor
class LogCallback(BaseCallback):
"""Print progress every N steps."""
def __init__(self, print_freq=10000):
super().__init__()
self.print_freq = print_freq
self.best_reward = float("-inf")  # rewards can be negative early in training
def _on_step(self):
if self.num_timesteps % self.print_freq < self.training_env.num_envs:
infos = self.locals.get("infos", [])
rewards = [info.get("episode", {}).get("r", 0) for info in infos if "episode" in info]
if rewards:
mean_r = sum(rewards) / len(rewards)
self.best_reward = max(self.best_reward, max(rewards))
print(
f" step {self.num_timesteps:>8d} | "
f"mean_reward = {mean_r:.2f} | "
f"best = {self.best_reward:.2f}"
)
return True
class VideoCallback(BaseCallback):
"""Record a depth-frame video of the agent playing every N steps."""
def __init__(self, map_file, bot_file, freq=250_000, video_steps=500):
super().__init__()
self.map_file = map_file
self.bot_file = bot_file
self.freq = freq
self.video_steps = video_steps
self.writer = None
self.next_video_at = freq
def _on_training_start(self):
# write to a separate subdirectory to avoid conflicting with SB3's writer
log_dir = (self.logger.dir or "runs/strafe-ai") + "_videos"
self.writer = SummaryWriter(log_dir)
print(f" [video] Callback active. Writer at: {log_dir}")
print(f" [video] Will record every {self.freq} steps")
print(f" [video] First video at step {self.next_video_at}")
def _on_step(self):
return True
def _on_rollout_end(self):
"""Called after each rollout collection."""
ts = self.num_timesteps
if ts % 50000 < 2048:
print(f" [video] rollout_end: timesteps={ts}, next_video_at={self.next_video_at}")
if ts >= self.next_video_at:
print(f" [video] Triggered at timestep {ts}")
self.next_video_at = ts + self.freq
try:
self._record_video()
except Exception as e:
print(f" [video] ERROR: {e}")
import traceback
traceback.print_exc()
def _on_training_end(self):
if self.writer:
self.writer.close()
def _record_video(self):
print(f" [video] Recording at step {self.num_timesteps}...")
import io
from PIL import Image, ImageDraw
from tensorboard.compat.proto.summary_pb2 import Summary
# create a fresh env for recording
env = StrafeEnvironment(self.map_file, self.bot_file)
obs, _ = env.reset()
pil_frames = []
path_positions = []
for _ in range(self.video_steps):
# capture depth frame
depth = obs["depth"]
frame = depth.reshape(36, 64)
# auto-contrast for visualization: stretch min/max to full 0-255 range
fmin, fmax = frame.min(), frame.max()
if fmax > fmin:
frame = (frame - fmin) / (fmax - fmin) * 255.0
else:
frame = np.zeros_like(frame)
img = Image.fromarray(frame.astype(np.uint8)).resize((256, 144), Image.Resampling.NEAREST)
pil_frames.append(img)
# record position
pos = env.get_position()
path_positions.append((pos[0], pos[1], pos[2]))
action, _ = self.model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
# --- Depth GIF ---
buf = io.BytesIO()
pil_frames[0].save(
buf, format="GIF", save_all=True,
append_images=pil_frames[1:], duration=33, loop=0,
)
buf.seek(0)
image_summary = Summary.Image(
encoded_image_string=buf.read(), height=144, width=256, colorspace=1,
)
summary = Summary(value=[Summary.Value(tag="agent/depth_view", image=image_summary)])
self.writer.file_writer.add_summary(summary, self.num_timesteps)
# --- Top-down map with path trace ---
map_img = self._draw_topdown_map(env, path_positions)
buf = io.BytesIO()
map_img.save(buf, format="PNG")
buf.seek(0)
map_summary = Summary.Image(
encoded_image_string=buf.read(),
height=map_img.height,
width=map_img.width,
colorspace=3,
)
summary = Summary(value=[Summary.Value(tag="agent/topdown_map", image=map_summary)])
self.writer.file_writer.add_summary(summary, self.num_timesteps)
self.writer.flush()
print(f" [video] Saved GIF + map at step {self.num_timesteps}")
def _draw_topdown_map(self, env, path_positions, img_size=512):
"""Draw a top-down map with platforms, WR path, and the agent's path."""
from PIL import Image, ImageDraw
platforms = env._env.get_map_positions()
wr_path = env._env.get_wr_path()
all_x = [p[0] for p in platforms] + [p[0] for p in path_positions] + [p[0] for p in wr_path]
all_z = [p[2] for p in platforms] + [p[2] for p in path_positions] + [p[2] for p in wr_path]
if not all_x:
return Image.new("RGB", (img_size, img_size), (30, 30, 30))
min_x, max_x = min(all_x) - 50, max(all_x) + 50
min_z, max_z = min(all_z) - 50, max(all_z) + 50
range_x = max_x - min_x
range_z = max_z - min_z
scale = (img_size - 20) / max(range_x, range_z)
def to_pixel(x, z):
px = int((x - min_x) * scale) + 10
pz = int((z - min_z) * scale) + 10
return px, pz
img = Image.new("RGB", (img_size, img_size), (30, 30, 30))
draw = ImageDraw.Draw(img)
# platforms
for p in platforms:
px, pz = to_pixel(p[0], p[2])
r = max(2, int(5 * scale))
draw.rectangle([px - r, pz - r, px + r, pz + r], fill=(80, 80, 120))
# WR path as dotted blue line
if len(wr_path) > 1:
wr_points = [to_pixel(p[0], p[2]) for p in wr_path]
for i in range(0, len(wr_points) - 1, 2):
if i + 1 < len(wr_points):
draw.line([wr_points[i], wr_points[i + 1]], fill=(50, 120, 255), width=1)
# agent path as solid red line
if len(path_positions) > 1:
points = [to_pixel(p[0], p[2]) for p in path_positions]
draw.line(points, fill=(255, 50, 50), width=2)
sx, sz = points[0]
draw.ellipse([sx - 4, sz - 4, sx + 4, sz + 4], fill=(50, 255, 50))
ex, ez = points[-1]
draw.ellipse([ex - 4, ez - 4, ex + 4, ez + 4], fill=(255, 255, 50))
return img
def make_env(map_file, bot_file):
"""Factory function for creating environments."""
def _init():
return Monitor(StrafeEnvironment(map_file, bot_file))
return _init
def train(map_file, bot_file, total_timesteps, lr, device_name, n_envs, resume_from=None):
print(f"Map: {map_file}")
print(f"Bot: {bot_file}")
print(f"Device: {device_name}")
print(f"Environments: {n_envs}")
if n_envs > 1:
env = SubprocVecEnv([make_env(map_file, bot_file) for _ in range(n_envs)])
else:
env = DummyVecEnv([make_env(map_file, bot_file)])
if resume_from:
print(f"Resuming from: {resume_from}")
model = PPO.load(resume_from, env=env, device=device_name,
tensorboard_log="runs/")
else:
policy_kwargs = dict(
features_extractor_class=StrafeFeaturesExtractor,
net_arch=[256, 128],
log_std_init=-0.5, # start with std ≈ 0.6
)
model = PPO(
"MultiInputPolicy",
env,
learning_rate=lr,
n_steps=2048 // n_envs,
batch_size=256,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.01,
policy_kwargs=policy_kwargs,
verbose=1,
tensorboard_log="runs/",
device=device_name,
)
print(f"Policy: {sum(p.numel() for p in model.policy.parameters()):,} parameters")
print(f"Training for {total_timesteps:,} timesteps...")
checkpoint_cb = CheckpointCallback(
save_freq=250_000 // n_envs,
save_path="checkpoints/",
name_prefix="ppo_strafe",
)
video_cb = VideoCallback(
map_file=map_file,
bot_file=bot_file,
freq=250_000,
video_steps=500,
)
model.learn(
total_timesteps=total_timesteps,
callback=[LogCallback(), checkpoint_cb, video_cb],
tb_log_name="strafe-ai",
)
model.save("ppo_strafe_ai")
print("Saved model to ppo_strafe_ai.zip")
env.close()
def main():
parser = argparse.ArgumentParser(description="Train strafe-ai with PPO")
parser.add_argument("--map-file", required=True, help="Path to .snfm map file")
parser.add_argument("--bot-file", required=True, help="Path to .qbot WR replay file")
parser.add_argument("--timesteps", type=int, default=10_000_000)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--n-envs", type=int, default=1, help="Number of parallel environments")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--resume", default=None, help="Path to checkpoint to resume from (no .zip)")
args = parser.parse_args()
train(args.map_file, args.bot_file, args.timesteps, args.lr, args.device, args.n_envs, args.resume)
if __name__ == "__main__":
main()
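`_draw_topdown_map` in train.py projects world (x, z) into image space with a single uniform scale so the map keeps its aspect ratio. The coordinate transform in isolation, with the same 50-unit world padding and 10-px border; the factory wrapper is illustrative:

```python
def make_to_pixel(xs, zs, img_size=512, margin=10, pad=50):
    """Build a world (x, z) -> pixel mapper over the padded bounds of xs/zs,
    scaled uniformly by the larger of the two world ranges."""
    min_x, max_x = min(xs) - pad, max(xs) + pad
    min_z, max_z = min(zs) - pad, max(zs) + pad
    scale = (img_size - 2 * margin) / max(max_x - min_x, max_z - min_z)
    def to_pixel(x, z):
        return (int((x - min_x) * scale) + margin,
                int((z - min_z) * scale) + margin)
    return to_pixel
```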

81
strafe_ai/watch.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Watch the agent play — prints actions and position each step.
Usage:
python -m strafe_ai.watch --map-file map.snfm --bot-file replay.qbot [--model ppo_strafe_ai]
"""
import argparse
import numpy as np
from strafe_ai.environment import StrafeEnvironment
def watch(map_file, bot_file, model_path, steps):
env = StrafeEnvironment(map_file, bot_file)
model = None
if model_path:
from stable_baselines3 import PPO
model = PPO.load(model_path)
print(f"Loaded model from {model_path}")
else:
print("No model - using random actions")
obs, _ = env.reset()
total_reward = 0.0
print(f"{'step':>5} | {'fwd':>3} {'lft':>3} {'bck':>3} {'rgt':>3} {'jmp':>3} | "
f"{'m_dx':>6} {'m_dy':>6} | {'pos_x':>8} {'pos_y':>8} {'pos_z':>8} | "
f"{'vel_x':>7} {'vel_y':>7} {'vel_z':>7} | {'speed':>6} | {'reward':>7}")
print("-" * 120)
for step in range(steps):
if model:
action, _ = model.predict(obs, deterministic=True)
else:
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
pos = env.get_position()
vel = env._env.get_velocity()
speed = (vel[0] ** 2 + vel[2] ** 2) ** 0.5
fb = float(action[0])
lr = float(action[1])
fwd = "W" if fb > 0 else "."
bck = "S" if fb < 0 else "."
lft = "A" if lr < 0 else "."
rgt = "D" if lr > 0 else "."
jmp = "^" if action[2] > 0 else "."
m_dx = action[3] * 300.0  # match the yaw scaling applied in StrafeEnvironment.step
m_dy = 0.0
print(f"{step:5d} | {fwd:>3} {lft:>3} {bck:>3} {rgt:>3} {jmp:>3} | "
f"{m_dx:6.1f} {m_dy:6.1f} | "
f"{pos[0]:8.1f} {pos[1]:8.1f} {pos[2]:8.1f} | "
f"{vel[0]:7.1f} {vel[1]:7.1f} {vel[2]:7.1f} | "
f"{speed:6.1f} | {reward:7.3f}")
if terminated or truncated:
print(f"\n--- Episode ended at step {step} | Total reward: {total_reward:.2f} ---\n")
obs, _ = env.reset()
total_reward = 0.0
print(f"\nFinal total reward: {total_reward:.2f}")
def main():
parser = argparse.ArgumentParser(description="Watch the agent play")
parser.add_argument("--map-file", required=True)
parser.add_argument("--bot-file", required=True)
parser.add_argument("--model", default=None, help="Path to saved PPO model (no .zip)")
parser.add_argument("--steps", type=int, default=500)
args = parser.parse_args()
watch(args.map_file, args.bot_file, args.model, args.steps)
if __name__ == "__main__":
main()