70 Commits

Author SHA1 Message Date
Cameron Grant
4573aaa5ae Add dynamic speed bonus to reward function and adjust PPO training parameters. 2026-03-31 13:51:17 -07:00
Cameron Grant
76ff41203d Add support for resuming training from a saved checkpoint. Edited values. 2026-03-31 13:45:30 -07:00
87c3967caa Revert "work around strafe tick bug"
This reverts commit e23455bba2.
2026-03-31 13:00:26 -07:00
4916b76126 update physics 2026-03-31 13:00:26 -07:00
Cameron Grant
b30b43ed9d Added skip connections 2026-03-31 12:57:34 -07:00
Cameron Grant
b6b208975c Refactor action space to continuous movement and simplify mouse controls. 2026-03-31 12:51:44 -07:00
Cameron Grant
2795090f13 Fixed issues with watch helper 2026-03-31 12:51:25 -07:00
Cameron Grant
5564c6529a Added auto contrast to previews. 2026-03-31 12:43:06 -07:00
Cameron Grant
eb62ac2099 Continuous actions 2026-03-31 11:46:27 -07:00
Cameron Grant
5e385f806c Added CNN 2026-03-31 11:40:13 -07:00
a030afb018 remove extra rewards 2026-03-31 11:31:48 -07:00
e23455bba2 work around strafe tick bug 2026-03-31 11:20:09 -07:00
f7521d49f2 idle so get_observation shows next state rather than this state 2026-03-31 10:43:47 -07:00
ee02c9cbda if statement has bad cost benefit 2026-03-31 10:05:10 -07:00
2c59742799 simplify get_map_positions 2026-03-31 09:58:48 -07:00
24bcb63d0e add velocity to position history, fix position history, use zero camera pitch 2026-03-31 09:52:46 -07:00
Cameron Grant
ccb4fb5791 Draw topdown map to tensorboard during training. 2026-03-30 20:56:20 -07:00
Cameron Grant
0444c6d68a Add watch.py for visualizing agent performance; update Python dependencies in poetry.lock. 2026-03-30 20:39:50 -07:00
bd2f60fb72 fix reset 2026-03-30 16:23:15 -07:00
4dd3201192 reward horizontal speed 2026-03-30 16:20:14 -07:00
44c8c53122 deny reward when pressing A and D 2026-03-30 16:16:08 -07:00
aaa5a158e8 allow playback to run until WR time elapses 2026-03-30 16:10:45 -07:00
Cameron Grant
cb59737985 Optimize PPO training parameters and switch to SubprocVecEnv for multi-environment setups. 2026-03-30 15:56:45 -07:00
Cameron Grant
5ba65ba4d0 Implement reward shaping for bhop mechanics and adjust PPO training parameters 2026-03-30 15:43:12 -07:00
eeb935dcd6 note stable-baselines3 in readme 2026-03-30 15:11:23 -07:00
Cameron Grant
89a398f1f6 Replace custom RL training loop with Stable-Baselines3 PPO implementation. Update training script and dependencies accordingly. 2026-03-30 15:09:17 -07:00
Cameron Grant
9b1f61b128 Refactor training pipeline to implement RL loop with policy gradient (REINFORCE). 2026-03-30 14:51:38 -07:00
Cameron Grant
ae02fbba79 Update simulation completion logic to include run_finished status. 2026-03-30 14:43:13 -07:00
Cameron Grant
679962024f Reset best_progress and run_finished in physics state initialization. 2026-03-30 14:41:48 -07:00
c11910b33e handsomely reward run completion 2026-03-30 14:29:09 -07:00
e8845ce28d incremental reward 2026-03-30 14:15:56 -07:00
73dcb93d5e more steps 2026-03-30 14:08:31 -07:00
3b00541644 add reward 2026-03-30 14:08:31 -07:00
b29d5f3845 normal torch ok 2026-03-30 13:15:25 -07:00
48cd49bd43 document setup using uv + fish 2026-03-30 13:13:19 -07:00
8ac74d36f0 CompleteMap is only needed for construction 2026-03-30 12:46:28 -07:00
9995e852d4 update ndarray 2026-03-30 12:40:16 -07:00
Cameron Grant
0a1c8068fe Add rustfmt.toml and improve code formatting in lib.rs. 2026-03-30 12:39:55 -07:00
792078121b CompleteMap is only needed for construction 2026-03-30 12:26:34 -07:00
df3e813dd9 fix float to bool 2026-03-30 12:25:37 -07:00
cea0bcbaf3 Use slice deconstruction 2026-03-30 12:18:42 -07:00
e79d0378ac Use helper function 2026-03-30 12:16:57 -07:00
Cameron Grant
d907672daa Add poetry.lock file for Python dependency management. 2026-03-30 12:10:13 -07:00
Cameron Grant
899278ff64 Add Cargo.lock file for Rust dependency management. 2026-03-30 11:53:59 -07:00
Cameron Grant
5da88a0f69 Converted full project to PyTorch. 2026-03-30 11:39:04 -07:00
96c21fffa9 separate depth from inputs 2026-03-28 08:44:22 -07:00
357e0f4a20 print best loss 2026-03-28 08:36:16 -07:00
31bfa208f8 include angles in history 2026-03-28 08:26:03 -07:00
1d09378bfd silence lint 2026-03-28 08:08:01 -07:00
bf2bf6d693 dropout first 2026-03-28 07:37:04 -07:00
a144ff1178 fix file name shenanigans 2026-03-27 19:33:19 -07:00
48f9657d0f don't hardcode map and bot 2026-03-27 16:46:02 -07:00
e38c0a92b4 remove chrono dep 2026-03-27 16:28:30 -07:00
148471dce1 simulate: add output_file argument 2026-03-27 16:28:30 -07:00
7bf439395b write model name based on num params 2026-03-27 16:16:26 -07:00
03f5eb5c13 tweak model 2026-03-27 16:04:03 -07:00
1e7bb6c4ce format 2026-03-27 15:57:32 -07:00
d8b0f9abbb rename GraphicsState to InputGenerator 2026-03-27 15:57:17 -07:00
fb8c6e2492 training options 2026-03-27 15:56:15 -07:00
18cad85b62 add cli args 2026-03-27 15:56:15 -07:00
b195a7eb95 split code into modules 2026-03-27 15:46:51 -07:00
4208090da0 add clap dep 2026-03-27 15:32:17 -07:00
e31b148f41 simulator 2026-03-27 15:29:32 -07:00
a05113baa5 feed position history into model inputs 2026-03-27 15:28:53 -07:00
9ad8a70ad0 hardcode depth "normalization" 2026-03-27 15:13:29 -07:00
e890623f2e add dropout to input 2026-03-27 15:00:30 -07:00
7d55e872e7 save model with current date 2026-03-27 11:56:45 -07:00
1e1cbeb180 graphics state 2026-03-27 11:56:45 -07:00
e19c46d851 save best model 2026-03-27 11:25:34 -07:00
59bb8eee12 implement training 2026-03-27 11:13:23 -07:00
14 changed files with 3701 additions and 6506 deletions

28
.gitignore vendored
View File

@@ -1 +1,29 @@
# Rust
/target
# Python
__pycache__
*.pyc
*.egg-info
.eggs
dist
build
.venv
# Data files
/files
*.snfm
*.qbot
*.snfb
*.model
*.bin
# TensorBoard
runs/
# IDE / tools
.claude
.idea
_rust.*
ppo_strafe_*_steps.zip

6520
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,22 @@
[package]
name = "strafe-ai"
name = "strafe-ai-py"
version = "0.1.0"
edition = "2024"
[lib]
name = "_rust"
crate-type = ["cdylib"]
[dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
pyo3 = { version = "0.28", features = ["extension-module"] }
numpy = "0.28"
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0"
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" }
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
strafesnet_physics = { version = "=0.0.2-surf2", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
pollster = "0.4.0"

77
README.md Normal file
View File

@@ -0,0 +1,77 @@
# strafe-ai-py
PyTorch + Rust environment for training an AI to play bhop maps.
## Architecture
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
The Rust side exposes a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
## Setup
Requires:
- Python 3.10+
- Rust toolchain
- CUDA toolkit
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
```bash
# create a virtual environment
uv venv
source .venv/bin/activate.fish # or .venv\Scripts\activate on Windows
# install python dependencies (torch first, then build tooling + RL libs)
uv pip install torch
uv pip install maturin gymnasium numpy tensorboard stable-baselines3
# build and install the Rust extension + Python deps
maturin develop --release
# or install with RL extras
pip install -e ".[rl]"
```
## Usage
```bash
# test the environment
python -m strafe_ai.train --map-file path/to/map.snfm
# tensorboard
tensorboard --logdir runs
```
## Project Structure
```
strafe_ai/ Python package
environment.py Gymnasium env wrapping Rust sim
model.py PyTorch model (StrafeNet)
train.py Training script
src/
lib.rs PyO3 bindings (physics + rendering)
```
## Environment API
```python
from strafe_ai import StrafeEnvironment
env = StrafeEnvironment("map.snfm")
obs, info = env.reset()
# obs["position_history"] — (80,) float32 — 10 recent entries of position + velocity + angles
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
action = env.action_space.sample() # 4 floats: forward/back, left/right, jump, mouse yaw
obs, reward, terminated, truncated, info = env.step(action)
```
## Next Steps
- [ ] Implement reward function (curve_dt along WR path)
- [ ] Train with PPO (Stable Baselines3)
- [ ] Add .qbot replay loading for imitation learning pretraining

2324
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

21
pyproject.toml Normal file
View File

@@ -0,0 +1,21 @@
[build-system]
requires = ["maturin>=1.5,<2.0"]
build-backend = "maturin"
[project]
name = "strafe-ai"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"gymnasium",
"numpy",
"tensorboard",
"stable-baselines3",
"pillow",
]
[tool.maturin]
python-source = "."
features = ["pyo3/extension-module"]
module-name = "strafe_ai._rust"

View File

@@ -1 +1 @@
hard_tabs = true
hard_tabs = true

502
src/lib.rs Normal file
View File

@@ -0,0 +1,502 @@
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
use strafesnet_roblox_bot_player::head::Time as PlaybackTime;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
const SIZE_X: u32 = 64;
const SIZE_Y: u32 = 36;
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
const POSITION_HISTORY: usize = 10;
const POSITION_HISTORY_SIZE: usize = size_of::<HistoricState>() / size_of::<f32>();
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
/// Interpret one continuous action component as a binary key press:
/// values strictly greater than 0.5 count as pressed.
fn float_to_bool(value: f32) -> bool {
	value > 0.5
}
/// One per-step snapshot of the player camera state; the most recent
/// POSITION_HISTORY of these form the position-history observation.
struct HistoricState {
	/// Camera position relative to the start-zone translation.
	pos: glam::Vec3,
	/// Camera velocity at the snapshot time.
	vel: glam::Vec3,
	/// Camera angles; x is used as yaw when rebasing history into the
	/// current camera frame (see get_observation). y is presumably pitch —
	/// confirm against simulate_move_angles.
	angles: glam::Vec2,
}
/// The Rust-side environment that wraps physics simulation and depth rendering.
/// Exposed to Python via PyO3.
#[pyclass]
struct StrafeEnv {
	// rendering
	wgpu_device: wgpu::Device,
	wgpu_queue: wgpu::Queue,
	graphics: strafesnet_roblox_bot_player::graphics::Graphics,
	/// Offscreen color target for the render pass; only the depth
	/// attachment is read back.
	graphics_texture_view: wgpu::TextureView,
	/// CPU-mappable buffer the depth texture is copied into for readback.
	output_staging_buffer: wgpu::Buffer,
	/// Scratch bytes copied out of the mapped staging buffer each frame.
	texture_data: Vec<u8>,
	// simulation
	geometry: PhysicsData,
	simulation: PhysicsState,
	/// Current physics time, advanced by STEP on every step() call.
	time: PhysicsTime,
	/// Accumulated virtual mouse position fed to the physics via
	/// SetNextMouse instructions.
	mouse_pos: glam::IVec2,
	/// Translation of the start zone; observed positions are relative to it.
	start_offset: glam::Vec3,
	// scoring
	/// World-record replay used as the progress reference.
	bot: strafesnet_roblox_bot_player::bot::CompleteBot,
	/// Spatial index over the replay path for closest-point queries.
	bvh: strafesnet_roblox_bot_player::bvh::Bvh,
	/// Duration of the world-record run; episodes end once time exceeds it.
	wr_time: PhysicsTime,
	/// Furthest playback time reached along the WR path this episode.
	best_progress: PlaybackTime,
	/// Set once the finish is detected; gates the one-time completion bonus.
	run_finished: bool,
	// position history
	position_history: Vec<HistoricState>,
	// map data (for platform visualization)
	map: strafesnet_common::map::CompleteMap,
}
#[pymethods]
impl StrafeEnv {
	/// Create a new environment from a map file path and a world-record
	/// replay (bot) file path. Loads both files, builds the BVH used for
	/// progress scoring, initializes headless wgpu for depth rendering,
	/// and performs an initial reset.
	#[new]
	fn new(map_path: &str, bot_path: &str) -> PyResult<Self> {
		// Read + decode the map.
		let map_file = std::fs::read(map_path)
			.map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
		let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
			.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
			.into_complete_map()
			.map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
		// Read + decode the reference bot replay.
		let bot_file = std::fs::read(bot_path)
			.map_err(|e| PyRuntimeError::new_err(format!("Failed to read bot: {e}")))?;
		let timelines =
			strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file))
				.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode bot: {e:?}")))?;
		let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines)
			.map_err(|e| PyRuntimeError::new_err(format!("Failed to init bot: {e:?}")))?;
		// Spatial index over the replay path for closest-point queries in reward().
		let bvh = strafesnet_roblox_bot_player::bvh::Bvh::new(&bot);
		let wr_time = bot
			.run_duration(strafesnet_roblox_bot_file::v0::ModeID(0))
			.unwrap();
		// compute start offset (observations are made relative to the start zone)
		let modes = map.modes.clone().denormalize();
		let mode = modes
			.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
			.ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
		let start_zone = map
			.models
			.get(mode.get_start().get() as usize)
			.ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
		let start_offset = glam::Vec3::from_array(
			start_zone
				.transform
				.translation
				.map(|f| f.into())
				.to_array(),
		);
		// init wgpu (headless: no display handle, no surface)
		let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
		let instance = wgpu::Instance::new(desc);
		let (wgpu_device, wgpu_queue) = pollster::block_on(async {
			let adapter = instance
				.request_adapter(&wgpu::RequestAdapterOptions {
					power_preference: wgpu::PowerPreference::HighPerformance,
					force_fallback_adapter: false,
					compatible_surface: None,
				})
				.await
				.unwrap();
			strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
				.await
				.unwrap()
		});
		let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
			&wgpu_device,
			&wgpu_queue,
			SIZE,
			FORMAT,
			LIMITS,
		);
		graphics
			.change_map(&wgpu_device, &wgpu_queue, &map)
			.unwrap();
		// Color target the scene renders into; the depth attachment that
		// accompanies it is what gets read back.
		let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
			label: Some("RGB texture"),
			format: FORMAT,
			size: wgpu::Extent3d {
				width: SIZE_X,
				height: SIZE_Y,
				depth_or_array_layers: 1,
			},
			mip_level_count: 1,
			sample_count: 1,
			dimension: wgpu::TextureDimension::D2,
			usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
			view_formats: &[],
		});
		let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
			label: Some("RGB texture view"),
			aspect: wgpu::TextureAspect::All,
			usage: Some(
				wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
			),
			..Default::default()
		});
		// Staging buffer sized for the row-padded depth copy (STRIDE_SIZE bytes/row).
		let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
		let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
			label: Some("Output staging buffer"),
			size: texture_data.capacity() as u64,
			usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
			mapped_at_creation: false,
		});
		let geometry = PhysicsData::new(&map);
		let simulation = PhysicsState::default();
		let mut env = Self {
			wgpu_device,
			wgpu_queue,
			graphics,
			graphics_texture_view,
			output_staging_buffer,
			texture_data,
			geometry,
			simulation,
			time: PhysicsTime::ZERO,
			// NOTE(review): -5300 appears to be the mouse x that yields the
			// desired starting yaw (mirrored in do_reset) — confirm.
			mouse_pos: glam::ivec2(-5300, 0),
			start_offset,
			position_history: Vec::with_capacity(POSITION_HISTORY),
			bot,
			bvh,
			best_progress: PlaybackTime::ZERO,
			wr_time: wr_time.coerce(),
			run_finished: false,
			map,
		};
		// initial reset
		env.do_reset();
		Ok(env)
	}
	/// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
	fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
		self.do_reset();
		self.get_observation()
	}
	/// Take one step with the given action.
	/// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
	/// Returns (position_history, depth, done)
	fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
		let &[
			move_forward,
			move_left,
			move_back,
			move_right,
			jump,
			mouse_dx,
			mouse_dy,
		] = action.as_slice()
		else {
			return Err(PyRuntimeError::new_err("Action must have 7 elements"));
		};
		// apply controls (each float is thresholded at 0.5 by float_to_bool)
		self.run_instruction(PhysicsInputInstruction::SetControl(
			SetControlInstruction::SetMoveForward(float_to_bool(move_forward)),
		));
		self.run_instruction(PhysicsInputInstruction::SetControl(
			SetControlInstruction::SetMoveLeft(float_to_bool(move_left)),
		));
		self.run_instruction(PhysicsInputInstruction::SetControl(
			SetControlInstruction::SetMoveBack(float_to_bool(move_back)),
		));
		self.run_instruction(PhysicsInputInstruction::SetControl(
			SetControlInstruction::SetMoveRight(float_to_bool(move_right)),
		));
		self.run_instruction(PhysicsInputInstruction::SetControl(
			SetControlInstruction::SetJump(float_to_bool(jump)),
		));
		// apply mouse: deltas accumulate into an absolute virtual mouse position
		self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
		let next_time = self.time + STEP;
		self.run_instruction(PhysicsInputInstruction::Mouse(
			MouseInstruction::SetNextMouse(MouseState {
				pos: self.mouse_pos,
				time: next_time,
			}),
		));
		self.time = next_time;
		// immediately idle to advance time before get_observation
		self.run_instruction(PhysicsInputInstruction::Idle);
		let (pos_hist, depth) = self.get_observation();
		// done on completion or timeout (world record not beaten)
		let done = self.run_finished || self.wr_time < self.time;
		Ok((pos_hist, depth, done))
	}
	/// Reward for the current state: incremental progress along the WR
	/// replay path, plus a one-time completion bonus scaled by WR duration.
	fn reward(&mut self) -> f32 {
		// this is going int -> float -> int but whatever.
		let point = self
			.simulation
			.body()
			.position
			.map(Into::<f32>::into)
			.to_array()
			.into();
		// Progress = playback time of the WR replay point nearest to us.
		let time = self.bvh.closest_time_to_point(&self.bot, point).unwrap();
		let progress = self.bot.playback_time(time);
		let mut reward = 0.0;
		// reward incremental progress (only when a new furthest point is reached)
		if self.best_progress < progress {
			let diff = progress - self.best_progress;
			self.best_progress = progress;
			let diff_f32: f32 = diff.into();
			reward += diff_f32;
		};
		// reward completion (paid out at most once per episode)
		if !self.run_finished {
			if let Some(finish_time) = self.simulation.get_finish_time() {
				let finish_time: f32 = finish_time.into();
				let wr_time: f32 = self.wr_time.into();
				// more reward for completing longer maps, and for faster finishes
				reward += wr_time * wr_time / finish_time;
				self.run_finished = true;
			}
		}
		reward
	}
	/// Get physics state: [vel_x, vel_y, vel_z]
	fn get_velocity(&self) -> Vec<f32> {
		self.simulation
			.body()
			.velocity
			.map(Into::<f32>::into)
			.to_array()
			.to_vec()
	}
	/// Get the current position as [x, y, z].
	fn get_position(&self) -> Vec<f32> {
		let trajectory = self.simulation.camera_trajectory(&self.geometry);
		let pos = trajectory
			.extrapolated_position(self.time)
			.map(Into::<f32>::into)
			.to_array();
		pos.to_vec()
	}
	/// Get the WR replay path as list of [x, y, z] positions.
	fn get_wr_path(&self) -> Vec<Vec<f32>> {
		self.bot
			.timelines()
			.output_events
			.iter()
			.map(|event| {
				vec![
					event.event.position.x,
					event.event.position.y,
					event.event.position.z,
				]
			})
			.collect()
	}
	/// Get map model positions as list of [x, y, z] (translation only).
	fn get_map_positions(&self) -> Vec<Vec<f32>> {
		self.map
			.models
			.iter()
			.map(|model| {
				model
					.transform
					.translation
					.map(|f| f.into())
					.to_array()
					.to_vec()
			})
			.collect()
	}
	/// Get observation dimensions: (position_history_len, depth_len).
	#[staticmethod]
	fn observation_sizes() -> (usize, usize) {
		(POSITION_HISTORY * POSITION_HISTORY_SIZE, DEPTH_PIXELS)
	}
	/// Get action size (7: four movement keys, jump, mouse dx/dy).
	#[staticmethod]
	fn action_size() -> usize {
		7
	}
}
impl StrafeEnv {
	/// Restore physics, time, mouse, history, and scoring state to their
	/// initial values, then replay the mode Reset/Restart instructions.
	fn do_reset(&mut self) {
		self.simulation = PhysicsState::default();
		self.time = PhysicsTime::ZERO;
		self.mouse_pos = glam::ivec2(-5300, 0);
		self.position_history.clear();
		self.best_progress = PlaybackTime::ZERO;
		self.run_finished = false;
		self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
		self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
			strafesnet_common::gameplay_modes::ModeId::MAIN,
		)));
	}
	/// Feed one instruction into the physics context, stamped with the
	/// environment's current time.
	fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
		PhysicsContext::run_input_instruction(
			&mut self.simulation,
			&self.geometry,
			TimedInstruction {
				time: self.time,
				instruction,
			},
		);
	}
	/// Build the observation pair (position_history, depth) for the current
	/// state, then record the current state into the history buffer.
	/// NOTE: the current state is appended AFTER the history input is
	/// built, so the newest entry in the observation is one step old.
	fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
		let trajectory = self.simulation.camera_trajectory(&self.geometry);
		// Camera position relative to the start zone.
		let pos = glam::Vec3::from_array(
			trajectory
				.extrapolated_position(self.time)
				.map(Into::<f32>::into)
				.to_array(),
		) - self.start_offset;
		let vel: glam::Vec3 = trajectory
			.extrapolated_velocity(self.time)
			.map(Into::<f32>::into)
			.to_array()
			.into();
		let camera = self.simulation.camera();
		let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
		// build position history input
		let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * POSITION_HISTORY_SIZE);
		// World -> camera transform using yaw only (pitch forced to zero).
		let cam_matrix =
			strafesnet_graphics::graphics::view_inv(pos, glam::vec2(angles.x, 0.0)).inverse();
		// Newest-first: express each historic state in the current camera frame.
		for state in self.position_history.iter().rev() {
			let relative_pos = cam_matrix.transform_point3(state.pos);
			let relative_vel = cam_matrix.transform_vector3(state.vel);
			let relative_ang = glam::vec2(angles.x - state.angles.x, state.angles.y);
			pos_hist.extend_from_slice(&relative_pos.to_array());
			pos_hist.extend_from_slice(&relative_vel.to_array());
			pos_hist.extend_from_slice(&relative_ang.to_array());
		}
		// pad remaining history with zeros
		for _ in self.position_history.len()..POSITION_HISTORY {
			pos_hist.extend(core::iter::repeat_n(0.0, POSITION_HISTORY_SIZE));
		}
		// update position history (fixed-size buffer: rotate once full so the
		// newest entry always ends up at the back)
		if self.position_history.len() < POSITION_HISTORY {
			self.position_history
				.push(HistoricState { pos, vel, angles });
		} else {
			self.position_history.rotate_left(1);
			*self.position_history.last_mut().unwrap() = HistoricState { pos, vel, angles };
		}
		// render depth
		let depth = self.render_depth(pos, angles);
		(pos_hist, depth)
	}
	/// Render one frame and read the depth attachment back to the CPU.
	/// Returns SIZE_X * SIZE_Y values remapped from raw [0, 1] depth to
	/// [-1, 1] via 1 - 2*d.
	fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
		let mut encoder =
			self.wgpu_device
				.create_command_encoder(&wgpu::CommandEncoderDescriptor {
					label: Some("depth encoder"),
				});
		self.graphics
			.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
		// Copy the depth texture into the mappable staging buffer; rows are
		// padded to STRIDE_SIZE to satisfy wgpu's 256-byte row alignment.
		encoder.copy_texture_to_buffer(
			wgpu::TexelCopyTextureInfo {
				texture: self.graphics.depth_texture(),
				mip_level: 0,
				origin: wgpu::Origin3d::ZERO,
				aspect: wgpu::TextureAspect::All,
			},
			wgpu::TexelCopyBufferInfo {
				buffer: &self.output_staging_buffer,
				layout: wgpu::TexelCopyBufferLayout {
					offset: 0,
					bytes_per_row: Some(STRIDE_SIZE),
					rows_per_image: Some(SIZE_Y),
				},
			},
			wgpu::Extent3d {
				width: SIZE_X,
				height: SIZE_Y,
				depth_or_array_layers: 1,
			},
		);
		self.wgpu_queue.submit([encoder.finish()]);
		// Synchronously wait for the GPU, then map the buffer for reading.
		let buffer_slice = self.output_staging_buffer.slice(..);
		let (sender, receiver) = std::sync::mpsc::channel();
		buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
		self.wgpu_device
			.poll(wgpu::PollType::wait_indefinitely())
			.unwrap();
		receiver.recv().unwrap().unwrap();
		let mut depth = Vec::with_capacity(DEPTH_PIXELS);
		{
			let view = buffer_slice.get_mapped_range();
			self.texture_data.extend_from_slice(&view[..]);
		}
		self.output_staging_buffer.unmap();
		// Strip the per-row padding and decode little-endian f32 depth values.
		for y in 0..SIZE_Y {
			depth.extend(
				self.texture_data[(STRIDE_SIZE * y) as usize
					..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
					.chunks_exact(4)
					.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
			);
		}
		// texture_data is reused as scratch; clear it for the next frame.
		self.texture_data.clear();
		depth
	}
}
/// Python module definition: exposes the environment class as `_rust`.
#[pymodule]
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
	// add_class already returns PyResult<()>, so propagate it directly.
	m.add_class::<StrafeEnv>()
}

View File

@@ -1,147 +0,0 @@
use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
use burn::prelude::*;
type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_common::session::Time as SessionTime;
use strafesnet_graphics::setup;
// Network dimensions for the burn-based prototype.
const INPUT: usize = 2;
const HIDDEN: usize = 64;
// One output scalar per control:
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
const OUTPUT: usize = 7;
/// Fully-connected network: one input layer, a stack of hidden layers with
/// ReLU, and a sigmoid-squashed output layer (one scalar per control).
#[derive(Module, Debug)]
struct Net<B: Backend> {
	input: Linear<B>,
	// NOTE(review): the array length 64 is hard-coded separately from the
	// HIDDEN width constant — they coincide but are independent values.
	hidden: [Linear<B>; 64],
	output: Linear<B>,
	activation: Relu,
	sigmoid: Sigmoid,
}
impl<B: Backend> Net<B> {
	/// Build all layers on the given backend device.
	fn init(device: &B::Device) -> Self {
		Self {
			input: LinearConfig::new(INPUT, HIDDEN).init(device),
			hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)),
			output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
			activation: Relu::new(),
			sigmoid: Sigmoid::new(),
		}
	}
	/// Forward pass: ReLU after the input layer and every hidden layer,
	/// sigmoid on the output so all control values land in (0, 1).
	fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
		let x = self.input.forward(input);
		let mut x = self.activation.forward(x);
		for layer in &self.hidden {
			x = layer.forward(x);
			x = self.activation.forward(x);
		}
		let x = self.output.forward(x);
		self.sigmoid.forward(x)
	}
}
/// Prototype training loop: loads a map + replay and sets up headless wgpu
/// (all currently unused scaffolding), then fits the network against an
/// all-zero input/target pair with SGD + MSE as a smoke test.
fn training() {
	// load map
	// load replay
	// setup player
	const SIZE_X: usize = 64;
	const SIZE_Y: usize = 36;
	let map_file = include_bytes!("../bhop_marble_5692093612.snfm");
	let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
	// read files
	let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
		.unwrap()
		.into_complete_map()
		.unwrap();
	let timelines =
		strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
	let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
	let mut playback_head =
		strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO);
	// setup graphics
	let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
	let instance = wgpu::Instance::new(desc);
	let (device, queue) = pollster::block_on(async {
		let adapter = instance
			.request_adapter(&wgpu::RequestAdapterOptions {
				power_preference: wgpu::PowerPreference::HighPerformance,
				force_fallback_adapter: false,
				compatible_surface: None,
			})
			.await
			.unwrap();
		setup::step4::request_device(&adapter, LIMITS)
			.await
			.unwrap()
	});
	const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
	let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
		&device,
		&queue,
		[SIZE_X as u32, SIZE_Y as u32].into(),
		FORMAT,
		LIMITS,
	);
	// setup simulation
	// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
	// NOTE: this `device` shadows the wgpu device above with the burn
	// backend's default device.
	let device = Default::default();
	let mut model: Net<TrainingBackend> = Net::init(&device);
	let mut optim = SgdConfig::new().init();
	// Placeholder training data: all-zero input and target tensors.
	let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
	let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
	const LEARNING_RATE: f64 = 0.5;
	const EPOCHS: usize = 100;
	for epoch in 0..EPOCHS {
		let predictions = model.forward(inputs.clone());
		let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
		// Print every (EPOCHS >> 4) epochs and at the final epoch.
		if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
			// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
			println!(
				" epoch {:>5} | loss = {:.8}",
				epoch,
				loss.clone().into_scalar()
			);
		}
		let grads = loss.backward();
		let grads = GradientsParams::from_grads(grads, &model);
		model = optim.step(LEARNING_RATE, model, grads);
	}
}
/// Placeholder for the inference entry point — body was never implemented.
fn inference() {
	// load map
	// setup simulation
	// setup agent-simulation feedback loop
	// go!
}
// Intentionally empty: neither training() nor inference() is called here.
fn main() {}

4
strafe_ai/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeFeaturesExtractor
__all__ = ["StrafeEnvironment", "StrafeFeaturesExtractor"]

99
strafe_ai/environment.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
Usage:
env = StrafeEnvironment("map.snfm", "replay.qbot")
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
"""
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from strafe_ai._rust import StrafeEnv
class StrafeEnvironment(gym.Env):
    """
    Bhop environment with a continuous action space.

    Observations:
        - depth: (1, 36, 64) — depth buffer as single-channel image for CNN
        - position_history: (80,) — 10 entries x 8 floats (pos + vel + angles)

    Actions: 4 continuous floats
        [forward/back axis, left/right axis, jump, mouse yaw]
    The sign of each movement axis selects which key is pressed (exactly
    zero presses neither); jump fires when its value is > 0; mouse yaw is
    scaled to 100 mouse counts per step and pitch is locked to 0.  These
    are expanded to the 7-element control vector the Rust sim expects.
    """

    metadata = {"render_modes": ["none"]}

    # Horizontal-speed shaping anneals linearly to zero over this many
    # lifetime steps (counted across all episodes of this env instance).
    _SPEED_BONUS_STEPS = 500_000

    def __init__(self, map_path: str, bot_path: str):
        super().__init__()
        self._env = StrafeEnv(map_path, bot_path)
        pos_size, depth_size = StrafeEnv.observation_sizes()
        # The flat depth buffer is reshaped to an image in _make_obs; make
        # sure the Rust side still agrees with the (1, 36, 64) layout.
        assert depth_size == 1 * 36 * 64, "Rust depth buffer size changed"
        self.observation_space = spaces.Dict({
            "depth": spaces.Box(-1.0, 1.0, shape=(1, 36, 64), dtype=np.float32),
            "position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
        })
        # 2 movement axes + jump + mouse yaw = 4 floats
        self.action_space = spaces.Box(
            low=np.array([-1, -1, -1, -1], dtype=np.float32),
            high=np.array([1, 1, 1, 1], dtype=np.float32),
        )
        # Previously created lazily via getattr() inside step(); initialize
        # here so the attribute always exists.
        self._lifetime_steps = 0

    def _make_obs(self, pos_hist, depth):
        """Convert the flat Rust observation arrays into the dict obs."""
        return {
            "depth": np.array(depth, dtype=np.float32).reshape(1, 36, 64),
            "position_history": np.array(pos_hist, dtype=np.float32),
        }

    def reset(self, seed=None, options=None):
        """Reset the simulation.  Returns (obs, info)."""
        super().reset(seed=seed)
        pos_hist, depth = self._env.reset()
        return self._make_obs(pos_hist, depth), {}

    def step(self, action):
        """Advance one physics tick.

        Returns the Gymnasium 5-tuple (obs, reward, terminated, truncated,
        info); truncated is always False since the Rust side folds the
        timeout into its single `done` flag.
        """
        # forward/back axis: >0 = forward, <0 = back
        fb = float(action[0])
        fwd = 1.0 if fb > 0 else 0.0
        back = 1.0 if fb < 0 else 0.0
        # left/right axis: >0 = right, <0 = left
        lr = float(action[1])
        right = 1.0 if lr > 0 else 0.0
        left = 1.0 if lr < 0 else 0.0
        # jump
        jump = 1.0 if action[2] > 0 else 0.0
        # mouse yaw only, no pitch
        mouse_dx = float(action[3]) * 100.0
        mouse_dy = 0.0
        controls = [fwd, left, back, right, jump, mouse_dx, mouse_dy]
        pos_hist, depth, done = self._env.step(controls)
        obs = self._make_obs(pos_hist, depth)
        # progress along WR path (primary signal)
        reward = self._env.reward() * 100.0
        # speed bonus that fades out over training; lifetime steps persist
        # across episodes so the shaping anneals globally.
        self._lifetime_steps += 1
        speed_weight = max(0.0, 1.0 - self._lifetime_steps / self._SPEED_BONUS_STEPS)
        if speed_weight > 0:
            vel = self._env.get_velocity()
            # horizontal (xz-plane) speed, capped so it never dominates
            hspeed = float(np.hypot(vel[0], vel[2]))
            reward += min(hspeed * 0.1, 1.0) * speed_weight
        return obs, reward, done, False, {}

    def get_position(self):
        """Current player position as a float32 [x, y, z] array."""
        return np.array(self._env.get_position(), dtype=np.float32)

90
strafe_ai/model.py Normal file
View File

@@ -0,0 +1,90 @@
"""
Custom feature extractor for SB3 — ResNet-style CNN on depth buffer,
MLP on position history (which includes velocity).
"""
import torch
import torch.nn as nn
import gymnasium as gym
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers wrapped in an identity skip."""

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # 3x3 same-padding conv that preserves the channel count
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # add the skip connection before the final activation
        return torch.relu(out + x)
class StrafeFeaturesExtractor(BaseFeaturesExtractor):
    """
    Encodes the dict observation into a single flat feature vector:
      - depth (1, 36, 64)      -> ResNet-style CNN -> 256 features
      - position_history (80,) -> MLP              -> 64 features
    The two parts are concatenated for 320 features total.
    """

    def __init__(self, observation_space: gym.spaces.Dict):
        super().__init__(observation_space, features_dim=320)
        depth_shape = observation_space["depth"].shape  # (1, H, W)
        h, w = depth_shape[1], depth_shape[2]
        # Stem conv then two more stages; each strided conv halves the
        # spatial resolution and is followed by a residual block.  Global
        # average pooling makes the head independent of (h, w).
        stages = [
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ResBlock(32),
        ]
        for c_in, c_out in ((32, 64), (64, 128)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                ResBlock(c_out),
            ])
        stages.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(),
        ])
        self.depth_cnn = nn.Sequential(*stages)
        # MLP over the flat position history (positions + velocities + angles)
        pos_size = observation_space["position_history"].shape[0]
        self.position_mlp = nn.Sequential(
            nn.Linear(pos_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

    def forward(self, observations: dict) -> torch.Tensor:
        """Run both branches and concatenate their features along dim 1."""
        cnn_features = self.depth_cnn(observations["depth"])
        mlp_features = self.position_mlp(observations["position_history"])
        return torch.cat((cnn_features, mlp_features), dim=1)

294
strafe_ai/train.py Normal file
View File

@@ -0,0 +1,294 @@
"""
RL training script for strafe-ai using PPO.
Usage:
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot
"""
import argparse
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from strafe_ai.environment import StrafeEnvironment
from strafe_ai.model import StrafeFeaturesExtractor
class LogCallback(BaseCallback):
"""Print progress every N steps."""
def __init__(self, print_freq=10000):
super().__init__()
self.print_freq = print_freq
self.best_reward = 0.0
def _on_step(self):
if self.num_timesteps % self.print_freq < self.training_env.num_envs:
infos = self.locals.get("infos", [])
rewards = [info.get("episode", {}).get("r", 0) for info in infos if "episode" in info]
if rewards:
mean_r = sum(rewards) / len(rewards)
self.best_reward = max(self.best_reward, max(rewards))
print(
f" step {self.num_timesteps:>8d} | "
f"mean_reward = {mean_r:.2f} | "
f"best = {self.best_reward:.2f}"
)
return True
class VideoCallback(BaseCallback):
"""Record a depth-frame video of the agent playing every N steps."""
def __init__(self, map_file, bot_file, freq=250_000, video_steps=500):
super().__init__()
self.map_file = map_file
self.bot_file = bot_file
self.freq = freq
self.video_steps = video_steps
self.writer = None
self.next_video_at = freq
def _on_training_start(self):
# write to a separate subdirectory to avoid conflicting with SB3's writer
log_dir = (self.logger.dir or "runs/strafe-ai") + "_videos"
self.writer = SummaryWriter(log_dir)
print(f" [video] Callback active. Writer at: {log_dir}")
print(f" [video] Will record every {self.freq} steps")
print(f" [video] First video at step {self.next_video_at}")
def _on_step(self):
return True
def _on_rollout_end(self):
"""Called after each rollout collection."""
ts = self.num_timesteps
if ts % 50000 < 2048:
print(f" [video] rollout_end: timesteps={ts}, next_video_at={self.next_video_at}")
if ts >= self.next_video_at:
print(f" [video] Triggered at timestep {ts}")
self.next_video_at = ts + self.freq
try:
self._record_video()
except Exception as e:
print(f" [video] ERROR: {e}")
import traceback
traceback.print_exc()
def _on_training_end(self):
if self.writer:
self.writer.close()
def _record_video(self):
print(f" [video] Recording at step {self.num_timesteps}...")
import io
from PIL import Image, ImageDraw
from tensorboard.compat.proto.summary_pb2 import Summary
# create a fresh env for recording
env = StrafeEnvironment(self.map_file, self.bot_file)
obs, _ = env.reset()
pil_frames = []
path_positions = []
for _ in range(self.video_steps):
# capture depth frame
depth = obs["depth"]
frame = depth.reshape(36, 64)
# auto-contrast for visualization: stretch min/max to full 0-255 range
fmin, fmax = frame.min(), frame.max()
if fmax > fmin:
frame = (frame - fmin) / (fmax - fmin) * 255.0
else:
frame = np.zeros_like(frame)
img = Image.fromarray(frame.astype(np.uint8)).resize((256, 144), Image.Resampling.NEAREST)
pil_frames.append(img)
# record position
pos = env.get_position()
path_positions.append((pos[0], pos[1], pos[2]))
action, _ = self.model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
# --- Depth GIF ---
buf = io.BytesIO()
pil_frames[0].save(
buf, format="GIF", save_all=True,
append_images=pil_frames[1:], duration=33, loop=0,
)
buf.seek(0)
image_summary = Summary.Image(
encoded_image_string=buf.read(), height=144, width=256, colorspace=1,
)
summary = Summary(value=[Summary.Value(tag="agent/depth_view", image=image_summary)])
self.writer.file_writer.add_summary(summary, self.num_timesteps)
# --- Top-down map with path trace ---
map_img = self._draw_topdown_map(env, path_positions)
buf = io.BytesIO()
map_img.save(buf, format="PNG")
buf.seek(0)
map_summary = Summary.Image(
encoded_image_string=buf.read(),
height=map_img.height,
width=map_img.width,
colorspace=3,
)
summary = Summary(value=[Summary.Value(tag="agent/topdown_map", image=map_summary)])
self.writer.file_writer.add_summary(summary, self.num_timesteps)
self.writer.flush()
print(f" [video] Saved GIF + map at step {self.num_timesteps}")
def _draw_topdown_map(self, env, path_positions, img_size=512):
"""Draw a top-down map with platforms, WR path, and the agent's path."""
from PIL import Image, ImageDraw
platforms = env._env.get_map_positions()
wr_path = env._env.get_wr_path()
all_x = [p[0] for p in platforms] + [p[0] for p in path_positions] + [p[0] for p in wr_path]
all_z = [p[2] for p in platforms] + [p[2] for p in path_positions] + [p[2] for p in wr_path]
if not all_x:
return Image.new("RGB", (img_size, img_size), (30, 30, 30))
min_x, max_x = min(all_x) - 50, max(all_x) + 50
min_z, max_z = min(all_z) - 50, max(all_z) + 50
range_x = max_x - min_x
range_z = max_z - min_z
scale = (img_size - 20) / max(range_x, range_z)
def to_pixel(x, z):
px = int((x - min_x) * scale) + 10
pz = int((z - min_z) * scale) + 10
return px, pz
img = Image.new("RGB", (img_size, img_size), (30, 30, 30))
draw = ImageDraw.Draw(img)
# platforms
for p in platforms:
px, pz = to_pixel(p[0], p[2])
r = max(2, int(5 * scale))
draw.rectangle([px - r, pz - r, px + r, pz + r], fill=(80, 80, 120))
# WR path as dotted blue line
if len(wr_path) > 1:
wr_points = [to_pixel(p[0], p[2]) for p in wr_path]
for i in range(0, len(wr_points) - 1, 2):
if i + 1 < len(wr_points):
draw.line([wr_points[i], wr_points[i + 1]], fill=(50, 120, 255), width=1)
# agent path as solid red line
if len(path_positions) > 1:
points = [to_pixel(p[0], p[2]) for p in path_positions]
draw.line(points, fill=(255, 50, 50), width=2)
sx, sz = points[0]
draw.ellipse([sx - 4, sz - 4, sx + 4, sz + 4], fill=(50, 255, 50))
ex, ez = points[-1]
draw.ellipse([ex - 4, ez - 4, ex + 4, ez + 4], fill=(255, 255, 50))
return img
def make_env(map_file, bot_file):
"""Factory function for creating environments."""
def _init():
return Monitor(StrafeEnvironment(map_file, bot_file))
return _init
def train(map_file, bot_file, total_timesteps, lr, device_name, n_envs, resume_from=None):
print(f"Map: {map_file}")
print(f"Bot: {bot_file}")
print(f"Device: {device_name}")
print(f"Environments: {n_envs}")
if n_envs > 1:
env = SubprocVecEnv([make_env(map_file, bot_file) for _ in range(n_envs)])
else:
env = DummyVecEnv([make_env(map_file, bot_file)])
if resume_from:
print(f"Resuming from: {resume_from}")
model = PPO.load(resume_from, env=env, device=device_name,
tensorboard_log="runs/")
else:
policy_kwargs = dict(
features_extractor_class=StrafeFeaturesExtractor,
net_arch=[256, 128],
log_std_init=-0.5, # start with std ≈ 0.6
)
model = PPO(
"MultiInputPolicy",
env,
learning_rate=lr,
n_steps=2048 // n_envs,
batch_size=256,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.01,
policy_kwargs=policy_kwargs,
verbose=1,
tensorboard_log="runs/",
device=device_name,
)
print(f"Policy: {sum(p.numel() for p in model.policy.parameters()):,} parameters")
print(f"Training for {total_timesteps:,} timesteps...")
checkpoint_cb = CheckpointCallback(
save_freq=250_000 // n_envs,
save_path="checkpoints/",
name_prefix="ppo_strafe",
)
video_cb = VideoCallback(
map_file=map_file,
bot_file=bot_file,
freq=250_000,
video_steps=500,
)
model.learn(
total_timesteps=total_timesteps,
callback=[LogCallback(), checkpoint_cb, video_cb],
tb_log_name="strafe-ai",
)
model.save("ppo_strafe_ai")
print("Saved model to ppo_strafe_ai.zip")
env.close()
def main():
parser = argparse.ArgumentParser(description="Train strafe-ai with PPO")
parser.add_argument("--map-file", required=True, help="Path to .snfm map file")
parser.add_argument("--bot-file", required=True, help="Path to .qbot WR replay file")
parser.add_argument("--timesteps", type=int, default=10_000_000)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--n-envs", type=int, default=1, help="Number of parallel environments")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--resume", default=None, help="Path to checkpoint to resume from (no .zip)")
args = parser.parse_args()
train(args.map_file, args.bot_file, args.timesteps, args.lr, args.device, args.n_envs, args.resume)
if __name__ == "__main__":
main()

81
strafe_ai/watch.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Watch the agent play — prints actions and position each step.
Usage:
python -m strafe_ai.watch --map-file map.snfm --bot-file replay.qbot [--model ppo_strafe_ai]
"""
import argparse
import numpy as np
from strafe_ai.environment import StrafeEnvironment
def watch(map_file, bot_file, model_path, steps):
env = StrafeEnvironment(map_file, bot_file)
model = None
if model_path:
from stable_baselines3 import PPO
model = PPO.load(model_path)
print(f"Loaded model from {model_path}")
else:
print("No model - using random actions")
obs, _ = env.reset()
total_reward = 0.0
print(f"{'step':>5} | {'fwd':>3} {'lft':>3} {'bck':>3} {'rgt':>3} {'jmp':>3} | "
f"{'m_dx':>6} {'m_dy':>6} | {'pos_x':>8} {'pos_y':>8} {'pos_z':>8} | "
f"{'vel_x':>7} {'vel_y':>7} {'vel_z':>7} | {'speed':>6} | {'reward':>7}")
print("-" * 120)
for step in range(steps):
if model:
action, _ = model.predict(obs, deterministic=True)
else:
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
pos = env.get_position()
vel = env._env.get_velocity()
speed = (vel[0] ** 2 + vel[2] ** 2) ** 0.5
fb = float(action[0])
lr = float(action[1])
fwd = "W" if fb > 0 else "."
bck = "S" if fb < 0 else "."
lft = "A" if lr < 0 else "."
rgt = "D" if lr > 0 else "."
jmp = "^" if action[2] > 0 else "."
m_dx = action[3] * 100.0
m_dy = 0.0
print(f"{step:5d} | {fwd:>3} {lft:>3} {bck:>3} {rgt:>3} {jmp:>3} | "
f"{m_dx:6.1f} {m_dy:6.1f} | "
f"{pos[0]:8.1f} {pos[1]:8.1f} {pos[2]:8.1f} | "
f"{vel[0]:7.1f} {vel[1]:7.1f} {vel[2]:7.1f} | "
f"{speed:6.1f} | {reward:7.3f}")
if terminated or truncated:
print(f"\n--- Episode ended at step {step} | Total reward: {total_reward:.2f} ---\n")
obs, _ = env.reset()
total_reward = 0.0
print(f"\nFinal total reward: {total_reward:.2f}")
def main():
parser = argparse.ArgumentParser(description="Watch the agent play")
parser.add_argument("--map-file", required=True)
parser.add_argument("--bot-file", required=True)
parser.add_argument("--model", default=None, help="Path to saved PPO model (no .zip)")
parser.add_argument("--steps", type=int, default=500)
args = parser.parse_args()
watch(args.map_file, args.bot_file, args.model, args.steps)
if __name__ == "__main__":
main()