Compare commits
70 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4573aaa5ae | ||
|
|
76ff41203d | ||
|
87c3967caa
|
|||
|
4916b76126
|
|||
|
|
b30b43ed9d | ||
|
|
b6b208975c | ||
|
|
2795090f13 | ||
|
|
5564c6529a | ||
|
|
eb62ac2099 | ||
|
|
5e385f806c | ||
|
a030afb018
|
|||
|
e23455bba2
|
|||
|
f7521d49f2
|
|||
|
ee02c9cbda
|
|||
|
2c59742799
|
|||
|
24bcb63d0e
|
|||
|
|
ccb4fb5791 | ||
|
|
0444c6d68a | ||
|
bd2f60fb72
|
|||
|
4dd3201192
|
|||
|
44c8c53122
|
|||
|
aaa5a158e8
|
|||
|
|
cb59737985 | ||
|
|
5ba65ba4d0 | ||
|
eeb935dcd6
|
|||
|
|
89a398f1f6 | ||
|
|
9b1f61b128 | ||
|
|
ae02fbba79 | ||
|
|
679962024f | ||
|
c11910b33e
|
|||
|
e8845ce28d
|
|||
|
73dcb93d5e
|
|||
|
3b00541644
|
|||
|
b29d5f3845
|
|||
|
48cd49bd43
|
|||
|
8ac74d36f0
|
|||
|
9995e852d4
|
|||
|
|
0a1c8068fe | ||
|
792078121b
|
|||
|
df3e813dd9
|
|||
|
cea0bcbaf3
|
|||
|
e79d0378ac
|
|||
|
|
d907672daa | ||
|
|
899278ff64 | ||
|
|
5da88a0f69 | ||
|
96c21fffa9
|
|||
|
357e0f4a20
|
|||
|
31bfa208f8
|
|||
|
1d09378bfd
|
|||
|
bf2bf6d693
|
|||
|
a144ff1178
|
|||
|
48f9657d0f
|
|||
|
e38c0a92b4
|
|||
|
148471dce1
|
|||
|
7bf439395b
|
|||
|
03f5eb5c13
|
|||
|
1e7bb6c4ce
|
|||
|
d8b0f9abbb
|
|||
|
fb8c6e2492
|
|||
|
18cad85b62
|
|||
|
b195a7eb95
|
|||
|
4208090da0
|
|||
|
e31b148f41
|
|||
|
a05113baa5
|
|||
|
9ad8a70ad0
|
|||
|
e890623f2e
|
|||
|
7d55e872e7
|
|||
|
1e1cbeb180
|
|||
|
e19c46d851
|
|||
|
59bb8eee12
|
28
.gitignore
vendored
28
.gitignore
vendored
@@ -1 +1,29 @@
|
||||
# Rust
|
||||
/target
|
||||
|
||||
# Python
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.egg-info
|
||||
.eggs
|
||||
dist
|
||||
build
|
||||
.venv
|
||||
|
||||
# Data files
|
||||
/files
|
||||
*.snfm
|
||||
*.qbot
|
||||
*.snfb
|
||||
*.model
|
||||
*.bin
|
||||
|
||||
# TensorBoard
|
||||
runs/
|
||||
|
||||
# IDE / tools
|
||||
.claude
|
||||
.idea
|
||||
_rust.*
|
||||
|
||||
ppo_strafe_*_steps.zip
|
||||
6520
Cargo.lock
generated
6520
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
18
Cargo.toml
18
Cargo.toml
@@ -1,16 +1,22 @@
|
||||
[package]
|
||||
name = "strafe-ai"
|
||||
name = "strafe-ai-py"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
name = "_rust"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
|
||||
pyo3 = { version = "0.28", features = ["extension-module"] }
|
||||
numpy = "0.28"
|
||||
glam = "0.32.1"
|
||||
pollster = "0.4.0"
|
||||
wgpu = "29.0.0"
|
||||
|
||||
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
|
||||
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" }
|
||||
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
|
||||
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
|
||||
strafesnet_physics = { version = "=0.0.2-surf2", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
|
||||
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
|
||||
pollster = "0.4.0"
|
||||
|
||||
77
README.md
Normal file
77
README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# strafe-ai-py
|
||||
|
||||
PyTorch + Rust environment for training an AI to play bhop maps.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
|
||||
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
|
||||
|
||||
The Rust side exposes a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
|
||||
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
|
||||
|
||||
## Setup
|
||||
|
||||
Requires:
|
||||
- Python 3.10+
|
||||
- Rust toolchain
|
||||
- CUDA toolkit
|
||||
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
|
||||
|
||||
```bash
|
||||
# create a virtual environment
|
||||
uv venv
|
||||
source .venv/bin/activate.fish # or .venv\Scripts\activate on Windows
|
||||
|
||||
# install maturin
|
||||
uv pip install torch
|
||||
uv pip install maturin gymnasium numpy tensorboard stable-baselines3
|
||||
|
||||
# build and install the Rust extension + Python deps
|
||||
maturin develop --release
|
||||
|
||||
# or install with RL extras
|
||||
pip install -e ".[rl]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# test the environment
|
||||
python -m strafe_ai.train --map-file path/to/map.snfm
|
||||
|
||||
# tensorboard
|
||||
tensorboard --logdir runs
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
strafe_ai/ Python package
|
||||
environment.py Gymnasium env wrapping Rust sim
|
||||
model.py PyTorch model (StrafeNet)
|
||||
train.py Training script
|
||||
src/
|
||||
lib.rs PyO3 bindings (physics + rendering)
|
||||
```
|
||||
|
||||
## Environment API
|
||||
|
||||
```python
|
||||
from strafe_ai import StrafeEnvironment
|
||||
|
||||
env = StrafeEnvironment("map.snfm")
|
||||
obs, info = env.reset()
|
||||
|
||||
# obs["position_history"] — (50,) float32 — 10 recent positions + angles
|
||||
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
|
||||
|
||||
action = env.action_space.sample() # 7 floats
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [ ] Implement reward function (curve_dt along WR path)
|
||||
- [ ] Train with PPO (Stable Baselines3)
|
||||
- [ ] Add .qbot replay loading for imitation learning pretraining
|
||||
2324
poetry.lock
generated
Normal file
2324
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
21
pyproject.toml
Normal file
21
pyproject.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.5,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "strafe-ai"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.0",
|
||||
"gymnasium",
|
||||
"numpy",
|
||||
"tensorboard",
|
||||
"stable-baselines3",
|
||||
"pillow",
|
||||
]
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "."
|
||||
features = ["pyo3/extension-module"]
|
||||
module-name = "strafe_ai._rust"
|
||||
@@ -1 +1 @@
|
||||
hard_tabs = true
|
||||
hard_tabs = true
|
||||
502
src/lib.rs
Normal file
502
src/lib.rs
Normal file
@@ -0,0 +1,502 @@
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use strafesnet_common::instruction::TimedInstruction;
|
||||
use strafesnet_common::mouse::MouseState;
|
||||
use strafesnet_common::physics::{
|
||||
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
|
||||
SetControlInstruction, Time as PhysicsTime,
|
||||
};
|
||||
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
|
||||
use strafesnet_roblox_bot_player::head::Time as PlaybackTime;
|
||||
|
||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
||||
|
||||
const SIZE_X: u32 = 64;
|
||||
const SIZE_Y: u32 = 36;
|
||||
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
|
||||
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
|
||||
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
|
||||
const POSITION_HISTORY: usize = 10;
|
||||
const POSITION_HISTORY_SIZE: usize = size_of::<HistoricState>() / size_of::<f32>();
|
||||
|
||||
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
|
||||
|
||||
fn float_to_bool(float: f32) -> bool {
|
||||
0.5 < float
|
||||
}
|
||||
|
||||
struct HistoricState {
|
||||
pos: glam::Vec3,
|
||||
vel: glam::Vec3,
|
||||
angles: glam::Vec2,
|
||||
}
|
||||
|
||||
/// The Rust-side environment that wraps physics simulation and depth rendering.
|
||||
/// Exposed to Python via PyO3.
|
||||
#[pyclass]
|
||||
struct StrafeEnv {
|
||||
// rendering
|
||||
wgpu_device: wgpu::Device,
|
||||
wgpu_queue: wgpu::Queue,
|
||||
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
|
||||
graphics_texture_view: wgpu::TextureView,
|
||||
output_staging_buffer: wgpu::Buffer,
|
||||
texture_data: Vec<u8>,
|
||||
|
||||
// simulation
|
||||
geometry: PhysicsData,
|
||||
simulation: PhysicsState,
|
||||
time: PhysicsTime,
|
||||
mouse_pos: glam::IVec2,
|
||||
start_offset: glam::Vec3,
|
||||
|
||||
// scoring
|
||||
bot: strafesnet_roblox_bot_player::bot::CompleteBot,
|
||||
bvh: strafesnet_roblox_bot_player::bvh::Bvh,
|
||||
wr_time: PhysicsTime,
|
||||
best_progress: PlaybackTime,
|
||||
run_finished: bool,
|
||||
|
||||
// position history
|
||||
position_history: Vec<HistoricState>,
|
||||
|
||||
// map data (for platform visualization)
|
||||
map: strafesnet_common::map::CompleteMap,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl StrafeEnv {
|
||||
/// Create a new environment from a map file path.
|
||||
#[new]
|
||||
fn new(map_path: &str, bot_path: &str) -> PyResult<Self> {
|
||||
let map_file = std::fs::read(map_path)
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
|
||||
.into_complete_map()
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
|
||||
|
||||
let bot_file = std::fs::read(bot_path)
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read bot: {e}")))?;
|
||||
let timelines =
|
||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file))
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode bot: {e:?}")))?;
|
||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines)
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to init bot: {e:?}")))?;
|
||||
let bvh = strafesnet_roblox_bot_player::bvh::Bvh::new(&bot);
|
||||
let wr_time = bot
|
||||
.run_duration(strafesnet_roblox_bot_file::v0::ModeID(0))
|
||||
.unwrap();
|
||||
|
||||
// compute start offset
|
||||
let modes = map.modes.clone().denormalize();
|
||||
let mode = modes
|
||||
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
|
||||
.ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
|
||||
let start_zone = map
|
||||
.models
|
||||
.get(mode.get_start().get() as usize)
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
|
||||
let start_offset = glam::Vec3::from_array(
|
||||
start_zone
|
||||
.transform
|
||||
.translation
|
||||
.map(|f| f.into())
|
||||
.to_array(),
|
||||
);
|
||||
|
||||
// init wgpu
|
||||
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
||||
let instance = wgpu::Instance::new(desc);
|
||||
let (wgpu_device, wgpu_queue) = pollster::block_on(async {
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: false,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
||||
&wgpu_device,
|
||||
&wgpu_queue,
|
||||
SIZE,
|
||||
FORMAT,
|
||||
LIMITS,
|
||||
);
|
||||
graphics
|
||||
.change_map(&wgpu_device, &wgpu_queue, &map)
|
||||
.unwrap();
|
||||
|
||||
let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
|
||||
label: Some("RGB texture"),
|
||||
format: FORMAT,
|
||||
size: wgpu::Extent3d {
|
||||
width: SIZE_X,
|
||||
height: SIZE_Y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
mip_level_count: 1,
|
||||
sample_count: 1,
|
||||
dimension: wgpu::TextureDimension::D2,
|
||||
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
view_formats: &[],
|
||||
});
|
||||
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
|
||||
label: Some("RGB texture view"),
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
usage: Some(
|
||||
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
|
||||
let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Output staging buffer"),
|
||||
size: texture_data.capacity() as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let geometry = PhysicsData::new(&map);
|
||||
let simulation = PhysicsState::default();
|
||||
|
||||
let mut env = Self {
|
||||
wgpu_device,
|
||||
wgpu_queue,
|
||||
graphics,
|
||||
graphics_texture_view,
|
||||
output_staging_buffer,
|
||||
texture_data,
|
||||
geometry,
|
||||
simulation,
|
||||
time: PhysicsTime::ZERO,
|
||||
mouse_pos: glam::ivec2(-5300, 0),
|
||||
start_offset,
|
||||
position_history: Vec::with_capacity(POSITION_HISTORY),
|
||||
bot,
|
||||
bvh,
|
||||
best_progress: PlaybackTime::ZERO,
|
||||
wr_time: wr_time.coerce(),
|
||||
run_finished: false,
|
||||
map,
|
||||
};
|
||||
|
||||
// initial reset
|
||||
env.do_reset();
|
||||
Ok(env)
|
||||
}
|
||||
|
||||
/// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
|
||||
fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
|
||||
self.do_reset();
|
||||
self.get_observation()
|
||||
}
|
||||
|
||||
/// Take one step with the given action.
|
||||
/// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
|
||||
/// Returns (position_history, depth, done)
|
||||
fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
|
||||
let &[
|
||||
move_forward,
|
||||
move_left,
|
||||
move_back,
|
||||
move_right,
|
||||
jump,
|
||||
mouse_dx,
|
||||
mouse_dy,
|
||||
] = action.as_slice()
|
||||
else {
|
||||
return Err(PyRuntimeError::new_err("Action must have 7 elements"));
|
||||
};
|
||||
|
||||
// apply controls
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveForward(float_to_bool(move_forward)),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveLeft(float_to_bool(move_left)),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveBack(float_to_bool(move_back)),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveRight(float_to_bool(move_right)),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetJump(float_to_bool(jump)),
|
||||
));
|
||||
|
||||
// apply mouse
|
||||
self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
|
||||
let next_time = self.time + STEP;
|
||||
self.run_instruction(PhysicsInputInstruction::Mouse(
|
||||
MouseInstruction::SetNextMouse(MouseState {
|
||||
pos: self.mouse_pos,
|
||||
time: next_time,
|
||||
}),
|
||||
));
|
||||
self.time = next_time;
|
||||
|
||||
// immediately idle to advance time before get_observation
|
||||
self.run_instruction(PhysicsInputInstruction::Idle);
|
||||
|
||||
let (pos_hist, depth) = self.get_observation();
|
||||
|
||||
// done on completion or timeout (world record not beaten)
|
||||
let done = self.run_finished || self.wr_time < self.time;
|
||||
|
||||
Ok((pos_hist, depth, done))
|
||||
}
|
||||
|
||||
fn reward(&mut self) -> f32 {
|
||||
// this is going int -> float -> int but whatever.
|
||||
let point = self
|
||||
.simulation
|
||||
.body()
|
||||
.position
|
||||
.map(Into::<f32>::into)
|
||||
.to_array()
|
||||
.into();
|
||||
let time = self.bvh.closest_time_to_point(&self.bot, point).unwrap();
|
||||
let progress = self.bot.playback_time(time);
|
||||
|
||||
let mut reward = 0.0;
|
||||
|
||||
// reward incremental progress
|
||||
if self.best_progress < progress {
|
||||
let diff = progress - self.best_progress;
|
||||
self.best_progress = progress;
|
||||
let diff_f32: f32 = diff.into();
|
||||
reward += diff_f32;
|
||||
};
|
||||
|
||||
// reward completion
|
||||
if !self.run_finished {
|
||||
if let Some(finish_time) = self.simulation.get_finish_time() {
|
||||
let finish_time: f32 = finish_time.into();
|
||||
let wr_time: f32 = self.wr_time.into();
|
||||
// more reward for completing longer maps
|
||||
reward += wr_time * wr_time / finish_time;
|
||||
self.run_finished = true;
|
||||
}
|
||||
}
|
||||
|
||||
reward
|
||||
}
|
||||
|
||||
/// Get physics state: [vel_x, vel_y, vel_z]
|
||||
fn get_velocity(&self) -> Vec<f32> {
|
||||
self.simulation
|
||||
.body()
|
||||
.velocity
|
||||
.map(Into::<f32>::into)
|
||||
.to_array()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
/// Get the current position as [x, y, z].
|
||||
fn get_position(&self) -> Vec<f32> {
|
||||
let trajectory = self.simulation.camera_trajectory(&self.geometry);
|
||||
let pos = trajectory
|
||||
.extrapolated_position(self.time)
|
||||
.map(Into::<f32>::into)
|
||||
.to_array();
|
||||
pos.to_vec()
|
||||
}
|
||||
|
||||
/// Get the WR replay path as list of [x, y, z] positions.
|
||||
fn get_wr_path(&self) -> Vec<Vec<f32>> {
|
||||
self.bot
|
||||
.timelines()
|
||||
.output_events
|
||||
.iter()
|
||||
.map(|event| {
|
||||
vec![
|
||||
event.event.position.x,
|
||||
event.event.position.y,
|
||||
event.event.position.z,
|
||||
]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get map model positions as list of [x, y, z] (translation only).
|
||||
fn get_map_positions(&self) -> Vec<Vec<f32>> {
|
||||
self.map
|
||||
.models
|
||||
.iter()
|
||||
.map(|model| {
|
||||
model
|
||||
.transform
|
||||
.translation
|
||||
.map(|f| f.into())
|
||||
.to_array()
|
||||
.to_vec()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get observation dimensions.
|
||||
#[staticmethod]
|
||||
fn observation_sizes() -> (usize, usize) {
|
||||
(POSITION_HISTORY * POSITION_HISTORY_SIZE, DEPTH_PIXELS)
|
||||
}
|
||||
|
||||
/// Get action size.
|
||||
#[staticmethod]
|
||||
fn action_size() -> usize {
|
||||
7
|
||||
}
|
||||
}
|
||||
|
||||
impl StrafeEnv {
|
||||
fn do_reset(&mut self) {
|
||||
self.simulation = PhysicsState::default();
|
||||
self.time = PhysicsTime::ZERO;
|
||||
self.mouse_pos = glam::ivec2(-5300, 0);
|
||||
self.position_history.clear();
|
||||
self.best_progress = PlaybackTime::ZERO;
|
||||
self.run_finished = false;
|
||||
|
||||
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
|
||||
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
|
||||
strafesnet_common::gameplay_modes::ModeId::MAIN,
|
||||
)));
|
||||
}
|
||||
|
||||
fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
|
||||
PhysicsContext::run_input_instruction(
|
||||
&mut self.simulation,
|
||||
&self.geometry,
|
||||
TimedInstruction {
|
||||
time: self.time,
|
||||
instruction,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
|
||||
let trajectory = self.simulation.camera_trajectory(&self.geometry);
|
||||
let pos = glam::Vec3::from_array(
|
||||
trajectory
|
||||
.extrapolated_position(self.time)
|
||||
.map(Into::<f32>::into)
|
||||
.to_array(),
|
||||
) - self.start_offset;
|
||||
let vel: glam::Vec3 = trajectory
|
||||
.extrapolated_velocity(self.time)
|
||||
.map(Into::<f32>::into)
|
||||
.to_array()
|
||||
.into();
|
||||
let camera = self.simulation.camera();
|
||||
let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
|
||||
|
||||
// build position history input
|
||||
let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * POSITION_HISTORY_SIZE);
|
||||
let cam_matrix =
|
||||
strafesnet_graphics::graphics::view_inv(pos, glam::vec2(angles.x, 0.0)).inverse();
|
||||
for state in self.position_history.iter().rev() {
|
||||
let relative_pos = cam_matrix.transform_point3(state.pos);
|
||||
let relative_vel = cam_matrix.transform_vector3(state.vel);
|
||||
let relative_ang = glam::vec2(angles.x - state.angles.x, state.angles.y);
|
||||
pos_hist.extend_from_slice(&relative_pos.to_array());
|
||||
pos_hist.extend_from_slice(&relative_vel.to_array());
|
||||
pos_hist.extend_from_slice(&relative_ang.to_array());
|
||||
}
|
||||
// pad remaining history with zeros
|
||||
for _ in self.position_history.len()..POSITION_HISTORY {
|
||||
pos_hist.extend(core::iter::repeat_n(0.0, POSITION_HISTORY_SIZE));
|
||||
}
|
||||
|
||||
// update position history
|
||||
if self.position_history.len() < POSITION_HISTORY {
|
||||
self.position_history
|
||||
.push(HistoricState { pos, vel, angles });
|
||||
} else {
|
||||
self.position_history.rotate_left(1);
|
||||
*self.position_history.last_mut().unwrap() = HistoricState { pos, vel, angles };
|
||||
}
|
||||
|
||||
// render depth
|
||||
let depth = self.render_depth(pos, angles);
|
||||
|
||||
(pos_hist, depth)
|
||||
}
|
||||
|
||||
fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
|
||||
let mut encoder =
|
||||
self.wgpu_device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("depth encoder"),
|
||||
});
|
||||
|
||||
self.graphics
|
||||
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
|
||||
|
||||
encoder.copy_texture_to_buffer(
|
||||
wgpu::TexelCopyTextureInfo {
|
||||
texture: self.graphics.depth_texture(),
|
||||
mip_level: 0,
|
||||
origin: wgpu::Origin3d::ZERO,
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
},
|
||||
wgpu::TexelCopyBufferInfo {
|
||||
buffer: &self.output_staging_buffer,
|
||||
layout: wgpu::TexelCopyBufferLayout {
|
||||
offset: 0,
|
||||
bytes_per_row: Some(STRIDE_SIZE),
|
||||
rows_per_image: Some(SIZE_Y),
|
||||
},
|
||||
},
|
||||
wgpu::Extent3d {
|
||||
width: SIZE_X,
|
||||
height: SIZE_Y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
);
|
||||
|
||||
self.wgpu_queue.submit([encoder.finish()]);
|
||||
|
||||
let buffer_slice = self.output_staging_buffer.slice(..);
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
|
||||
self.wgpu_device
|
||||
.poll(wgpu::PollType::wait_indefinitely())
|
||||
.unwrap();
|
||||
receiver.recv().unwrap().unwrap();
|
||||
|
||||
let mut depth = Vec::with_capacity(DEPTH_PIXELS);
|
||||
{
|
||||
let view = buffer_slice.get_mapped_range();
|
||||
self.texture_data.extend_from_slice(&view[..]);
|
||||
}
|
||||
self.output_staging_buffer.unmap();
|
||||
|
||||
for y in 0..SIZE_Y {
|
||||
depth.extend(
|
||||
self.texture_data[(STRIDE_SIZE * y) as usize
|
||||
..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
|
||||
.chunks_exact(4)
|
||||
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
|
||||
);
|
||||
}
|
||||
|
||||
self.texture_data.clear();
|
||||
depth
|
||||
}
|
||||
}
|
||||
|
||||
/// Python module definition
|
||||
#[pymodule]
|
||||
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<StrafeEnv>()?;
|
||||
Ok(())
|
||||
}
|
||||
147
src/main.rs
147
src/main.rs
@@ -1,147 +0,0 @@
|
||||
use burn::backend::Autodiff;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::nn::loss::{MseLoss, Reduction};
|
||||
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
|
||||
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
||||
use burn::prelude::*;
|
||||
|
||||
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||
|
||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
use strafesnet_common::session::Time as SessionTime;
|
||||
use strafesnet_graphics::setup;
|
||||
|
||||
const INPUT: usize = 2;
|
||||
const HIDDEN: usize = 64;
|
||||
// MoveForward
|
||||
// MoveLeft
|
||||
// MoveBack
|
||||
// MoveRight
|
||||
// Jump
|
||||
// mouse_dx
|
||||
// mouse_dy
|
||||
const OUTPUT: usize = 7;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Net<B: Backend> {
|
||||
input: Linear<B>,
|
||||
hidden: [Linear<B>; 64],
|
||||
output: Linear<B>,
|
||||
activation: Relu,
|
||||
sigmoid: Sigmoid,
|
||||
}
|
||||
impl<B: Backend> Net<B> {
|
||||
fn init(device: &B::Device) -> Self {
|
||||
Self {
|
||||
input: LinearConfig::new(INPUT, HIDDEN).init(device),
|
||||
hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)),
|
||||
output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
|
||||
activation: Relu::new(),
|
||||
sigmoid: Sigmoid::new(),
|
||||
}
|
||||
}
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = self.input.forward(input);
|
||||
let mut x = self.activation.forward(x);
|
||||
for layer in &self.hidden {
|
||||
x = layer.forward(x);
|
||||
x = self.activation.forward(x);
|
||||
}
|
||||
let x = self.output.forward(x);
|
||||
self.sigmoid.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn training() {
|
||||
// load map
|
||||
// load replay
|
||||
// setup player
|
||||
const SIZE_X: usize = 64;
|
||||
const SIZE_Y: usize = 36;
|
||||
|
||||
let map_file = include_bytes!("../bhop_marble_5692093612.snfm");
|
||||
let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
||||
|
||||
// read files
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.unwrap()
|
||||
.into_complete_map()
|
||||
.unwrap();
|
||||
let timelines =
|
||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
||||
|
||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
||||
let mut playback_head =
|
||||
strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO);
|
||||
|
||||
// setup graphics
|
||||
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
||||
let instance = wgpu::Instance::new(desc);
|
||||
let (device, queue) = pollster::block_on(async {
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: false,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
setup::step4::request_device(&adapter, LIMITS)
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
||||
let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
||||
&device,
|
||||
&queue,
|
||||
[SIZE_X as u32, SIZE_Y as u32].into(),
|
||||
FORMAT,
|
||||
LIMITS,
|
||||
);
|
||||
|
||||
// setup simulation
|
||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||
|
||||
let mut optim = SgdConfig::new().init();
|
||||
|
||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
||||
|
||||
const LEARNING_RATE: f64 = 0.5;
|
||||
const EPOCHS: usize = 100;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
let predictions = model.forward(inputs.clone());
|
||||
|
||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||
|
||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||
println!(
|
||||
" epoch {:>5} | loss = {:.8}",
|
||||
epoch,
|
||||
loss.clone().into_scalar()
|
||||
);
|
||||
}
|
||||
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
|
||||
model = optim.step(LEARNING_RATE, model, grads);
|
||||
}
|
||||
}
|
||||
|
||||
fn inference() {
|
||||
// load map
|
||||
// setup simulation
|
||||
// setup agent-simulation feedback loop
|
||||
// go!
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
4
strafe_ai/__init__.py
Normal file
4
strafe_ai/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from strafe_ai.environment import StrafeEnvironment
|
||||
from strafe_ai.model import StrafeFeaturesExtractor
|
||||
|
||||
__all__ = ["StrafeEnvironment", "StrafeFeaturesExtractor"]
|
||||
99
strafe_ai/environment.py
Normal file
99
strafe_ai/environment.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
|
||||
|
||||
Usage:
|
||||
env = StrafeEnvironment("map.snfm", "replay.qbot")
|
||||
obs, info = env.reset()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
"""
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from strafe_ai._rust import StrafeEnv
|
||||
|
||||
|
||||
class StrafeEnvironment(gym.Env):
|
||||
"""
|
||||
Bhop environment with continuous action space.
|
||||
|
||||
Observations:
|
||||
- depth: (1, 36, 64) — depth buffer as single-channel image for CNN
|
||||
- position_history: (80,) — 10 entries x 8 floats (pos + vel + angles)
|
||||
|
||||
Actions: 7 continuous floats
|
||||
[forward, left, back, right, jump, mouse_dx, mouse_dy]
|
||||
First 5 are thresholded at 0.5 for binary controls.
|
||||
mouse_dx/dy are scaled by 1/sqrt(speed) in the environment.
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["none"]}
|
||||
|
||||
def __init__(self, map_path: str, bot_path: str):
|
||||
super().__init__()
|
||||
|
||||
self._env = StrafeEnv(map_path, bot_path)
|
||||
pos_size, depth_size = StrafeEnv.observation_sizes()
|
||||
|
||||
self.observation_space = spaces.Dict({
|
||||
"depth": spaces.Box(-1.0, 1.0, shape=(1, 36, 64), dtype=np.float32),
|
||||
"position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
|
||||
})
|
||||
|
||||
# 2 movement axes + jump + mouse yaw = 4 floats
|
||||
self.action_space = spaces.Box(
|
||||
low=np.array([-1, -1, -1, -1], dtype=np.float32),
|
||||
high=np.array([1, 1, 1, 1], dtype=np.float32),
|
||||
)
|
||||
|
||||
def _make_obs(self, pos_hist, depth):
|
||||
return {
|
||||
"depth": np.array(depth, dtype=np.float32).reshape(1, 36, 64),
|
||||
"position_history": np.array(pos_hist, dtype=np.float32),
|
||||
}
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
pos_hist, depth = self._env.reset()
|
||||
return self._make_obs(pos_hist, depth), {}
|
||||
|
||||
def step(self, action):
|
||||
# forward/back axis: >0 = forward, <0 = back
|
||||
fb = float(action[0])
|
||||
fwd = 1.0 if fb > 0 else 0.0
|
||||
back = 1.0 if fb < 0 else 0.0
|
||||
|
||||
# left/right axis: >0 = right, <0 = left
|
||||
lr = float(action[1])
|
||||
right = 1.0 if lr > 0 else 0.0
|
||||
left = 1.0 if lr < 0 else 0.0
|
||||
|
||||
# jump
|
||||
jump = 1.0 if action[2] > 0 else 0.0
|
||||
|
||||
# mouse yaw only, no pitch
|
||||
mouse_dx = float(action[3]) * 100.0
|
||||
mouse_dy = 0.0
|
||||
|
||||
controls = [fwd, left, back, right, jump, mouse_dx, mouse_dy]
|
||||
pos_hist, depth, done = self._env.step(controls)
|
||||
|
||||
obs = self._make_obs(pos_hist, depth)
|
||||
|
||||
# progress along WR path (primary signal)
|
||||
reward = self._env.reward() * 100.0
|
||||
|
||||
# speed bonus that fades out over training
|
||||
# self._step_count tracks total lifetime steps across all episodes
|
||||
self._lifetime_steps = getattr(self, '_lifetime_steps', 0) + 1
|
||||
speed_weight = max(0.0, 1.0 - self._lifetime_steps / 500_000) # fades to 0 over 500k steps
|
||||
if speed_weight > 0:
|
||||
vel = self._env.get_velocity()
|
||||
hspeed = (vel[0] ** 2 + vel[2] ** 2) ** 0.5
|
||||
reward += min(hspeed * 0.1, 1.0) * speed_weight
|
||||
|
||||
return obs, reward, done, False, {}
|
||||
|
||||
def get_position(self):
    """Current player world position as a float32 numpy array."""
    raw = self._env.get_position()
    return np.array(raw, dtype=np.float32)
|
||||
90
strafe_ai/model.py
Normal file
90
strafe_ai/model.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Custom feature extractor for SB3 — ResNet-style CNN on depth buffer,
|
||||
MLP on position history (which includes velocity).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import gymnasium as gym
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv+BN stages with an identity skip.

    Channel count and spatial size are preserved, so the input can be
    added directly to the block output before the final activation.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out = torch.relu(out)
        out = self.bn2(self.conv2(out))
        # Skip connection is added before the final ReLU.
        return torch.relu(out + x)
|
||||
|
||||
|
||||
class StrafeFeaturesExtractor(BaseFeaturesExtractor):
    """
    Processes the multi-input observation:
    - depth (1, 36, 64) -> ResNet CNN -> 256 features
    - position_history (80,) -> MLP -> 64 features
    Concatenated -> 320 features total
    """

    def __init__(self, observation_space: gym.spaces.Dict):
        # 256 depth features + 64 position features (see forward()).
        features_dim = 320
        super().__init__(observation_space, features_dim=features_dim)

        # ResNet-style CNN for depth buffer.  The AdaptiveAvgPool2d at the
        # end makes the trunk independent of the exact input resolution,
        # so the depth H/W need not be inspected here.
        self.depth_cnn = nn.Sequential(
            # initial conv: 1 -> 32 channels, downsample
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            # residual block at 32 channels
            ResBlock(32),

            # downsample: 32 -> 64 channels
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            # residual block at 64 channels
            ResBlock(64),

            # downsample: 64 -> 128 channels
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            # residual block at 128 channels
            ResBlock(128),

            # global average pooling -> 128 features regardless of input size
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),

            nn.Linear(128, 256),
            nn.ReLU(),
        )

        # MLP for position history (includes velocity + angles)
        pos_size = observation_space["position_history"].shape[0]
        self.position_mlp = nn.Sequential(
            nn.Linear(pos_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

    def forward(self, observations: dict) -> torch.Tensor:
        """Encode each modality and concatenate along the feature axis."""
        depth_features = self.depth_cnn(observations["depth"])
        pos_features = self.position_mlp(observations["position_history"])
        return torch.cat([depth_features, pos_features], dim=1)
|
||||
294
strafe_ai/train.py
Normal file
294
strafe_ai/train.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
RL training script for strafe-ai using PPO.
|
||||
|
||||
Usage:
|
||||
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||
|
||||
from strafe_ai.environment import StrafeEnvironment
|
||||
from strafe_ai.model import StrafeFeaturesExtractor
|
||||
|
||||
|
||||
class LogCallback(BaseCallback):
    """Print mean/best episode return to stdout every N timesteps."""

    def __init__(self, print_freq=10000):
        super().__init__()
        self.print_freq = print_freq
        # Best single-episode return seen so far.  Starts at -inf so the
        # first real episode return is reported correctly even when it is
        # negative (the old 0.0 initializer masked negative bests).
        self.best_reward = float("-inf")

    def _on_step(self):
        # num_timesteps advances by num_envs per call, so compare the
        # residue against num_envs to fire roughly once per print_freq
        # timesteps.
        if self.num_timesteps % self.print_freq < self.training_env.num_envs:
            infos = self.locals.get("infos", [])
            # Monitor puts the finished episode's stats under "episode".
            rewards = [info.get("episode", {}).get("r", 0) for info in infos if "episode" in info]
            if rewards:
                mean_r = sum(rewards) / len(rewards)
                self.best_reward = max(self.best_reward, max(rewards))
                print(
                    f" step {self.num_timesteps:>8d} | "
                    f"mean_reward = {mean_r:.2f} | "
                    f"best = {self.best_reward:.2f}"
                )
        return True
|
||||
|
||||
|
||||
class VideoCallback(BaseCallback):
    """Record a depth-frame video of the agent playing every N steps.

    Rolls out a dedicated evaluation episode on a fresh environment and
    logs two images to TensorBoard: an animated GIF of the agent's depth
    view and a top-down map overlaying the agent's path on the WR path.
    """

    def __init__(self, map_file, bot_file, freq=250_000, video_steps=500):
        super().__init__()
        self.map_file = map_file
        self.bot_file = bot_file
        self.freq = freq                  # timesteps between recordings
        self.video_steps = video_steps    # env steps captured per video
        self.writer = None                # SummaryWriter, created lazily
        self.next_video_at = freq         # next timestep threshold

    def _on_training_start(self):
        # write to a separate subdirectory to avoid conflicting with SB3's writer
        log_dir = (self.logger.dir or "runs/strafe-ai") + "_videos"
        self.writer = SummaryWriter(log_dir)
        print(f" [video] Callback active. Writer at: {log_dir}")
        print(f" [video] Will record every {self.freq} steps")
        print(f" [video] First video at step {self.next_video_at}")

    def _on_step(self):
        # Recording happens at rollout boundaries, not per step.
        return True

    def _on_rollout_end(self):
        """Called after each rollout collection."""
        ts = self.num_timesteps
        if ts % 50000 < 2048:
            print(f" [video] rollout_end: timesteps={ts}, next_video_at={self.next_video_at}")
        if ts >= self.next_video_at:
            print(f" [video] Triggered at timestep {ts}")
            self.next_video_at = ts + self.freq
            try:
                self._record_video()
            except Exception as e:
                # Recording must never kill the training run.
                print(f" [video] ERROR: {e}")
                import traceback
                traceback.print_exc()

    def _on_training_end(self):
        if self.writer:
            self.writer.close()

    def _record_video(self):
        """Roll out `video_steps` steps and log GIF + map to TensorBoard."""
        print(f" [video] Recording at step {self.num_timesteps}...")

        import io
        from PIL import Image
        from tensorboard.compat.proto.summary_pb2 import Summary

        # create a fresh env for recording; close it in all cases so native
        # resources are released even if the rollout raises mid-way
        env = StrafeEnvironment(self.map_file, self.bot_file)
        try:
            obs, _ = env.reset()

            pil_frames = []
            path_positions = []

            for _ in range(self.video_steps):
                # capture depth frame
                depth = obs["depth"]
                frame = depth.reshape(36, 64)
                # auto-contrast for visualization: stretch min/max to full 0-255 range
                fmin, fmax = frame.min(), frame.max()
                if fmax > fmin:
                    frame = (frame - fmin) / (fmax - fmin) * 255.0
                else:
                    frame = np.zeros_like(frame)
                img = Image.fromarray(frame.astype(np.uint8)).resize((256, 144), Image.Resampling.NEAREST)
                pil_frames.append(img)

                # record position
                pos = env.get_position()
                path_positions.append((pos[0], pos[1], pos[2]))

                action, _ = self.model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _ = env.step(action)
                if terminated or truncated:
                    obs, _ = env.reset()

            # draw the map while the env (and its map data) is still alive
            map_img = self._draw_topdown_map(env, path_positions)
        finally:
            env.close()

        # --- Depth GIF ---
        buf = io.BytesIO()
        pil_frames[0].save(
            buf, format="GIF", save_all=True,
            append_images=pil_frames[1:], duration=33, loop=0,
        )
        buf.seek(0)
        # colorspace=1 -> grayscale
        image_summary = Summary.Image(
            encoded_image_string=buf.read(), height=144, width=256, colorspace=1,
        )
        summary = Summary(value=[Summary.Value(tag="agent/depth_view", image=image_summary)])
        self.writer.file_writer.add_summary(summary, self.num_timesteps)

        # --- Top-down map with path trace ---
        map_img_buf = io.BytesIO()
        map_img.save(map_img_buf, format="PNG")
        map_img_buf.seek(0)
        # colorspace=3 -> RGB
        map_summary = Summary.Image(
            encoded_image_string=map_img_buf.read(),
            height=map_img.height,
            width=map_img.width,
            colorspace=3,
        )
        summary = Summary(value=[Summary.Value(tag="agent/topdown_map", image=map_summary)])
        self.writer.file_writer.add_summary(summary, self.num_timesteps)

        self.writer.flush()
        print(f" [video] Saved GIF + map at step {self.num_timesteps}")

    def _draw_topdown_map(self, env, path_positions, img_size=512):
        """Draw a top-down map with platforms, WR path, and the agent's path."""
        from PIL import Image, ImageDraw

        platforms = env._env.get_map_positions()
        wr_path = env._env.get_wr_path()

        # Bounds over everything we draw, in world X/Z (top-down view).
        all_x = [p[0] for p in platforms] + [p[0] for p in path_positions] + [p[0] for p in wr_path]
        all_z = [p[2] for p in platforms] + [p[2] for p in path_positions] + [p[2] for p in wr_path]

        if not all_x:
            return Image.new("RGB", (img_size, img_size), (30, 30, 30))

        # 50-unit world margin so nothing sits on the image edge.
        min_x, max_x = min(all_x) - 50, max(all_x) + 50
        min_z, max_z = min(all_z) - 50, max(all_z) + 50
        range_x = max_x - min_x
        range_z = max_z - min_z
        # Uniform scale (same for both axes) preserves the map's aspect ratio.
        scale = (img_size - 20) / max(range_x, range_z)

        def to_pixel(x, z):
            px = int((x - min_x) * scale) + 10
            pz = int((z - min_z) * scale) + 10
            return px, pz

        img = Image.new("RGB", (img_size, img_size), (30, 30, 30))
        draw = ImageDraw.Draw(img)

        # platforms
        for p in platforms:
            px, pz = to_pixel(p[0], p[2])
            r = max(2, int(5 * scale))
            draw.rectangle([px - r, pz - r, px + r, pz + r], fill=(80, 80, 120))

        # WR path as dotted blue line (every other segment skipped)
        if len(wr_path) > 1:
            wr_points = [to_pixel(p[0], p[2]) for p in wr_path]
            for i in range(0, len(wr_points) - 1, 2):
                if i + 1 < len(wr_points):
                    draw.line([wr_points[i], wr_points[i + 1]], fill=(50, 120, 255), width=1)

        # agent path as solid red line, green dot = start, yellow dot = end
        if len(path_positions) > 1:
            points = [to_pixel(p[0], p[2]) for p in path_positions]
            draw.line(points, fill=(255, 50, 50), width=2)

            sx, sz = points[0]
            draw.ellipse([sx - 4, sz - 4, sx + 4, sz + 4], fill=(50, 255, 50))
            ex, ez = points[-1]
            draw.ellipse([ex - 4, ez - 4, ex + 4, ez + 4], fill=(255, 255, 50))

        return img
|
||||
|
||||
|
||||
def make_env(map_file, bot_file):
    """Return a zero-argument factory building a Monitor-wrapped env.

    SB3's vectorized env constructors expect callables, not instances,
    so the environment is created lazily inside the closure.
    """
    def _init():
        base_env = StrafeEnvironment(map_file, bot_file)
        return Monitor(base_env)
    return _init
|
||||
|
||||
|
||||
def train(map_file, bot_file, total_timesteps, lr, device_name, n_envs, resume_from=None):
    """Train (or resume training) a PPO agent on the strafe environment.

    Args:
        map_file: path to the .snfm map file.
        bot_file: path to the .qbot WR replay file.
        total_timesteps: total environment steps to train for.
        lr: PPO learning rate (ignored when resuming — PPO.load keeps the
            saved hyperparameters; TODO confirm if that is intended).
        device_name: torch device string ("cuda"/"cpu").
        n_envs: number of parallel environments.
        resume_from: optional checkpoint path (without .zip) to resume from.
    """
    print(f"Map: {map_file}")
    print(f"Bot: {bot_file}")
    print(f"Device: {device_name}")
    print(f"Environments: {n_envs}")

    # Subprocess workers only pay off with >1 env; otherwise run in-process.
    if n_envs > 1:
        env = SubprocVecEnv([make_env(map_file, bot_file) for _ in range(n_envs)])
    else:
        env = DummyVecEnv([make_env(map_file, bot_file)])

    if resume_from:
        print(f"Resuming from: {resume_from}")
        model = PPO.load(resume_from, env=env, device=device_name,
                         tensorboard_log="runs/")
    else:
        policy_kwargs = dict(
            features_extractor_class=StrafeFeaturesExtractor,
            net_arch=[256, 128],
            log_std_init=-0.5,  # start with std ≈ 0.6
        )

        model = PPO(
            "MultiInputPolicy",
            env,
            learning_rate=lr,
            # n_steps is per-env, so divide to keep ~2048 samples per rollout.
            n_steps=2048 // n_envs,
            batch_size=256,
            n_epochs=10,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.01,
            policy_kwargs=policy_kwargs,
            verbose=1,
            tensorboard_log="runs/",
            device=device_name,
        )

    print(f"Policy: {sum(p.numel() for p in model.policy.parameters()):,} parameters")
    print(f"Training for {total_timesteps:,} timesteps...")

    # save_freq is counted in per-env steps, so divide by n_envs to
    # checkpoint roughly every 250k total timesteps.
    checkpoint_cb = CheckpointCallback(
        save_freq=250_000 // n_envs,
        save_path="checkpoints/",
        name_prefix="ppo_strafe",
    )

    video_cb = VideoCallback(
        map_file=map_file,
        bot_file=bot_file,
        freq=250_000,
        video_steps=500,
    )

    model.learn(
        total_timesteps=total_timesteps,
        callback=[LogCallback(), checkpoint_cb, video_cb],
        tb_log_name="strafe-ai",
    )

    model.save("ppo_strafe_ai")
    print("Saved model to ppo_strafe_ai.zip")
    env.close()
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments and launch PPO training."""
    arg_parser = argparse.ArgumentParser(description="Train strafe-ai with PPO")
    arg_parser.add_argument("--map-file", required=True, help="Path to .snfm map file")
    arg_parser.add_argument("--bot-file", required=True, help="Path to .qbot WR replay file")
    arg_parser.add_argument("--timesteps", type=int, default=10_000_000)
    arg_parser.add_argument("--lr", type=float, default=3e-4)
    arg_parser.add_argument("--n-envs", type=int, default=1, help="Number of parallel environments")
    arg_parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    arg_parser.add_argument("--resume", default=None, help="Path to checkpoint to resume from (no .zip)")
    args = arg_parser.parse_args()

    train(
        args.map_file,
        args.bot_file,
        args.timesteps,
        args.lr,
        args.device,
        args.n_envs,
        args.resume,
    )


if __name__ == "__main__":
    main()
|
||||
81
strafe_ai/watch.py
Normal file
81
strafe_ai/watch.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Watch the agent play — prints actions and position each step.
|
||||
|
||||
Usage:
|
||||
python -m strafe_ai.watch --map-file map.snfm --bot-file replay.qbot [--model ppo_strafe_ai]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from strafe_ai.environment import StrafeEnvironment
|
||||
|
||||
|
||||
def watch(map_file, bot_file, model_path, steps):
    """Run the agent (or random actions) and print a per-step trace.

    Each row shows the pressed keys (WASD/jump), the mouse yaw delta,
    position, velocity, horizontal speed, and the step reward.

    Args:
        map_file: path to the .snfm map file.
        bot_file: path to the .qbot WR replay file.
        model_path: saved PPO model path (without .zip), or None for
            random actions sampled from the action space.
        steps: number of environment steps to run.
    """
    env = StrafeEnvironment(map_file, bot_file)

    model = None
    if model_path:
        # Imported lazily so random-action mode works without SB3 installed.
        from stable_baselines3 import PPO
        model = PPO.load(model_path)
        print(f"Loaded model from {model_path}")
    else:
        print("No model - using random actions")

    obs, _ = env.reset()
    total_reward = 0.0

    # Column header matching the row format below.
    print(f"{'step':>5} | {'fwd':>3} {'lft':>3} {'bck':>3} {'rgt':>3} {'jmp':>3} | "
          f"{'m_dx':>6} {'m_dy':>6} | {'pos_x':>8} {'pos_y':>8} {'pos_z':>8} | "
          f"{'vel_x':>7} {'vel_y':>7} {'vel_z':>7} | {'speed':>6} | {'reward':>7}")
    print("-" * 120)

    for step in range(steps):
        if model:
            action, _ = model.predict(obs, deterministic=True)
        else:
            action = env.action_space.sample()

        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward

        pos = env.get_position()
        vel = env._env.get_velocity()
        # Horizontal speed only; vertical component (index 1) excluded.
        speed = (vel[0] ** 2 + vel[2] ** 2) ** 0.5

        # Decode the continuous action back into key symbols, mirroring
        # the thresholds used in StrafeEnvironment.step.
        fb = float(action[0])
        lr = float(action[1])
        fwd = "W" if fb > 0 else "."
        bck = "S" if fb < 0 else "."
        lft = "A" if lr < 0 else "."
        rgt = "D" if lr > 0 else "."
        jmp = "^" if action[2] > 0 else "."
        m_dx = action[3] * 100.0
        m_dy = 0.0

        print(f"{step:5d} | {fwd:>3} {lft:>3} {bck:>3} {rgt:>3} {jmp:>3} | "
              f"{m_dx:6.1f} {m_dy:6.1f} | "
              f"{pos[0]:8.1f} {pos[1]:8.1f} {pos[2]:8.1f} | "
              f"{vel[0]:7.1f} {vel[1]:7.1f} {vel[2]:7.1f} | "
              f"{speed:6.1f} | {reward:7.3f}")

        if terminated or truncated:
            print(f"\n--- Episode ended at step {step} | Total reward: {total_reward:.2f} ---\n")
            obs, _ = env.reset()
            total_reward = 0.0

    print(f"\nFinal total reward: {total_reward:.2f}")
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments and run the spectator loop."""
    cli = argparse.ArgumentParser(description="Watch the agent play")
    cli.add_argument("--map-file", required=True)
    cli.add_argument("--bot-file", required=True)
    cli.add_argument("--model", default=None, help="Path to saved PPO model (no .zip)")
    cli.add_argument("--steps", type=int, default=500)
    args = cli.parse_args()

    watch(args.map_file, args.bot_file, args.model, args.steps)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user