forked from StrafesNET/strafe-ai
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d907672daa | ||
|
|
899278ff64 | ||
|
|
5da88a0f69 | ||
|
96c21fffa9
|
|||
|
357e0f4a20
|
|||
|
31bfa208f8
|
|||
|
1d09378bfd
|
|||
|
bf2bf6d693
|
|||
|
a144ff1178
|
|||
|
48f9657d0f
|
|||
|
e38c0a92b4
|
|||
|
148471dce1
|
|||
|
7bf439395b
|
|||
|
03f5eb5c13
|
|||
|
1e7bb6c4ce
|
|||
|
d8b0f9abbb
|
|||
|
fb8c6e2492
|
|||
|
18cad85b62
|
|||
|
b195a7eb95
|
|||
|
4208090da0
|
|||
|
e31b148f41
|
27
.gitignore
vendored
27
.gitignore
vendored
@@ -1,2 +1,27 @@
|
||||
/files
|
||||
# Rust
|
||||
/target
|
||||
|
||||
# Python
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.egg-info
|
||||
.eggs
|
||||
dist
|
||||
build
|
||||
.venv
|
||||
|
||||
# Data files
|
||||
/files
|
||||
*.snfm
|
||||
*.qbot
|
||||
*.snfb
|
||||
*.model
|
||||
*.bin
|
||||
|
||||
# TensorBoard
|
||||
runs/
|
||||
|
||||
# IDE / tools
|
||||
.claude
|
||||
.idea
|
||||
_rust.*
|
||||
|
||||
6517
Cargo.lock
generated
6517
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
10
Cargo.toml
10
Cargo.toml
@@ -1,10 +1,15 @@
|
||||
[package]
|
||||
name = "strafe-ai"
|
||||
name = "strafe-ai-py"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
name = "_rust"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
|
||||
pyo3 = { version = "0.28", features = ["extension-module"] }
|
||||
numpy = "0.28"
|
||||
glam = "0.32.1"
|
||||
pollster = "0.4.0"
|
||||
wgpu = "29.0.0"
|
||||
@@ -15,4 +20,3 @@ strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
|
||||
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
|
||||
chrono = { version = "0.4.44", default-features = false, features = ["now"] }
|
||||
|
||||
76
README.md
Normal file
76
README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# strafe-ai-py
|
||||
|
||||
PyTorch + Rust environment for training an AI to play bhop maps.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Rust** (via PyO3): Physics simulation + depth rendering — fast, deterministic
|
||||
- **Python** (PyTorch): Model, training, RL algorithms, TensorBoard logging
|
||||
|
||||
The Rust side exposes a [Gymnasium](https://gymnasium.farama.org/)-compatible environment,
|
||||
so any standard RL library (Stable Baselines3, CleanRL, etc.) works out of the box.
|
||||
|
||||
## Setup
|
||||
|
||||
Requires:
|
||||
- Python 3.10+
|
||||
- Rust toolchain
|
||||
- CUDA toolkit
|
||||
- [maturin](https://github.com/PyO3/maturin) for building the Rust extension
|
||||
|
||||
```bash
|
||||
# create a virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # or .venv\Scripts\activate on Windows
|
||||
|
||||
# install maturin
|
||||
pip install maturin
|
||||
|
||||
# build and install the Rust extension + Python deps
|
||||
maturin develop --release
|
||||
|
||||
# or install with RL extras
|
||||
pip install -e ".[rl]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# test the environment
|
||||
python -m strafe_ai.train --map-file path/to/map.snfm
|
||||
|
||||
# tensorboard
|
||||
tensorboard --logdir runs
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
strafe_ai/ Python package
|
||||
environment.py Gymnasium env wrapping Rust sim
|
||||
model.py PyTorch model (StrafeNet)
|
||||
train.py Training script
|
||||
src/
|
||||
lib.rs PyO3 bindings (physics + rendering)
|
||||
```
|
||||
|
||||
## Environment API
|
||||
|
||||
```python
|
||||
from strafe_ai import StrafeEnvironment
|
||||
|
||||
env = StrafeEnvironment("map.snfm")
|
||||
obs, info = env.reset()
|
||||
|
||||
# obs["position_history"] — (50,) float32 — 10 recent positions + angles
|
||||
# obs["depth"] — (2304,) float32 — 64x36 depth buffer
|
||||
|
||||
action = env.action_space.sample() # 7 floats
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [ ] Implement reward function (curve_dt along WR path)
|
||||
- [ ] Train with PPO (Stable Baselines3)
|
||||
- [ ] Add .qbot replay loading for imitation learning pretraining
|
||||
2319
poetry.lock
generated
Normal file
2319
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.5,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "strafe-ai"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.0",
|
||||
"gymnasium",
|
||||
"numpy",
|
||||
"tensorboard",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = ["stable-baselines3"]
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "."
|
||||
features = ["pyo3/extension-module"]
|
||||
module-name = "strafe_ai._rust"
|
||||
@@ -1 +0,0 @@
|
||||
hard_tabs = true
|
||||
375
src/lib.rs
Normal file
375
src/lib.rs
Normal file
@@ -0,0 +1,375 @@
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
|
||||
use strafesnet_common::instruction::TimedInstruction;
|
||||
use strafesnet_common::mouse::MouseState;
|
||||
use strafesnet_common::physics::{
|
||||
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
|
||||
SetControlInstruction, Time as PhysicsTime,
|
||||
};
|
||||
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
|
||||
|
||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
||||
|
||||
const SIZE_X: u32 = 64;
|
||||
const SIZE_Y: u32 = 36;
|
||||
const SIZE: glam::UVec2 = glam::uvec2(SIZE_X, SIZE_Y);
|
||||
const DEPTH_PIXELS: usize = (SIZE_X * SIZE_Y) as usize;
|
||||
const STRIDE_SIZE: u32 = (SIZE_X * size_of::<f32>() as u32).next_multiple_of(256);
|
||||
const POSITION_HISTORY: usize = 10;
|
||||
|
||||
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
|
||||
|
||||
/// The Rust-side environment that wraps physics simulation and depth rendering.
|
||||
/// Exposed to Python via PyO3.
|
||||
#[pyclass]
|
||||
struct StrafeEnv {
|
||||
// rendering
|
||||
wgpu_device: wgpu::Device,
|
||||
wgpu_queue: wgpu::Queue,
|
||||
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
|
||||
graphics_texture_view: wgpu::TextureView,
|
||||
output_staging_buffer: wgpu::Buffer,
|
||||
texture_data: Vec<u8>,
|
||||
|
||||
// simulation
|
||||
geometry: PhysicsData,
|
||||
simulation: PhysicsState,
|
||||
time: PhysicsTime,
|
||||
mouse_pos: glam::IVec2,
|
||||
start_offset: glam::Vec3,
|
||||
|
||||
// position history (relative positions + angles)
|
||||
position_history: Vec<(glam::Vec3, glam::Vec2)>,
|
||||
|
||||
// map data (kept for reset)
|
||||
map: strafesnet_common::map::CompleteMap,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl StrafeEnv {
|
||||
/// Create a new environment from a map file path.
|
||||
#[new]
|
||||
fn new(map_path: &str) -> PyResult<Self> {
|
||||
let map_file = std::fs::read(map_path)
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to read map: {e}")))?;
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to decode map: {e:?}")))?
|
||||
.into_complete_map()
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("Failed to load map: {e:?}")))?;
|
||||
|
||||
// compute start offset
|
||||
let modes = map.modes.clone().denormalize();
|
||||
let mode = modes
|
||||
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
|
||||
.ok_or_else(|| PyRuntimeError::new_err("No MAIN mode in map"))?;
|
||||
let start_zone = map
|
||||
.models
|
||||
.get(mode.get_start().get() as usize)
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Start zone not found"))?;
|
||||
let start_offset = glam::Vec3::from_array(
|
||||
start_zone.transform.translation.map(|f| f.into()).to_array(),
|
||||
);
|
||||
|
||||
// init wgpu
|
||||
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
||||
let instance = wgpu::Instance::new(desc);
|
||||
let (wgpu_device, wgpu_queue) = pollster::block_on(async {
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: false,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
strafesnet_graphics::setup::step4::request_device(&adapter, LIMITS)
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
||||
&wgpu_device, &wgpu_queue, SIZE, FORMAT, LIMITS,
|
||||
);
|
||||
graphics
|
||||
.change_map(&wgpu_device, &wgpu_queue, &map)
|
||||
.unwrap();
|
||||
|
||||
let graphics_texture = wgpu_device.create_texture(&wgpu::TextureDescriptor {
|
||||
label: Some("RGB texture"),
|
||||
format: FORMAT,
|
||||
size: wgpu::Extent3d {
|
||||
width: SIZE_X,
|
||||
height: SIZE_Y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
mip_level_count: 1,
|
||||
sample_count: 1,
|
||||
dimension: wgpu::TextureDimension::D2,
|
||||
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
view_formats: &[],
|
||||
});
|
||||
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
|
||||
label: Some("RGB texture view"),
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
usage: Some(
|
||||
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE_Y) as usize);
|
||||
let output_staging_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Output staging buffer"),
|
||||
size: texture_data.capacity() as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let geometry = PhysicsData::new(&map);
|
||||
let simulation = PhysicsState::default();
|
||||
|
||||
let mut env = Self {
|
||||
wgpu_device,
|
||||
wgpu_queue,
|
||||
graphics,
|
||||
graphics_texture_view,
|
||||
output_staging_buffer,
|
||||
texture_data,
|
||||
geometry,
|
||||
simulation,
|
||||
time: PhysicsTime::ZERO,
|
||||
mouse_pos: glam::ivec2(-5300, 0),
|
||||
start_offset,
|
||||
position_history: Vec::with_capacity(POSITION_HISTORY),
|
||||
map,
|
||||
};
|
||||
|
||||
// initial reset
|
||||
env.do_reset();
|
||||
Ok(env)
|
||||
}
|
||||
|
||||
/// Reset the environment. Returns (position_history, depth) as flat f32 arrays.
|
||||
fn reset(&mut self) -> (Vec<f32>, Vec<f32>) {
|
||||
self.do_reset();
|
||||
self.get_observation()
|
||||
}
|
||||
|
||||
/// Take one step with the given action.
|
||||
/// action: [move_forward, move_left, move_back, move_right, jump, mouse_dx, mouse_dy]
|
||||
/// Returns (position_history, depth, done)
|
||||
fn step(&mut self, action: Vec<f32>) -> PyResult<(Vec<f32>, Vec<f32>, bool)> {
|
||||
if action.len() != 7 {
|
||||
return Err(PyRuntimeError::new_err("Action must have 7 elements"));
|
||||
}
|
||||
|
||||
let move_forward = action[0] > 0.5;
|
||||
let move_left = action[1] > 0.5;
|
||||
let move_back = action[2] > 0.5;
|
||||
let move_right = action[3] > 0.5;
|
||||
let jump = action[4] > 0.5;
|
||||
let mouse_dx = action[5];
|
||||
let mouse_dy = action[6];
|
||||
|
||||
// apply controls
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveForward(move_forward),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveLeft(move_left),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveBack(move_back),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetMoveRight(move_right),
|
||||
));
|
||||
self.run_instruction(PhysicsInputInstruction::SetControl(
|
||||
SetControlInstruction::SetJump(jump),
|
||||
));
|
||||
|
||||
// apply mouse
|
||||
self.mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
|
||||
let next_time = self.time + STEP;
|
||||
PhysicsContext::run_input_instruction(
|
||||
&mut self.simulation,
|
||||
&self.geometry,
|
||||
TimedInstruction {
|
||||
time: self.time,
|
||||
instruction: PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(
|
||||
MouseState {
|
||||
pos: self.mouse_pos,
|
||||
time: next_time,
|
||||
},
|
||||
)),
|
||||
},
|
||||
);
|
||||
self.time = next_time;
|
||||
|
||||
let (pos_hist, depth) = self.get_observation();
|
||||
|
||||
// done after 20 seconds of simulation
|
||||
let done = self.time >= PhysicsTime::from_millis(20_000);
|
||||
|
||||
Ok((pos_hist, depth, done))
|
||||
}
|
||||
|
||||
/// Get the current position as [x, y, z].
|
||||
fn get_position(&self) -> Vec<f32> {
|
||||
let trajectory = self.simulation.camera_trajectory(&self.geometry);
|
||||
let pos = trajectory
|
||||
.extrapolated_position(self.time)
|
||||
.map(Into::<f32>::into)
|
||||
.to_array();
|
||||
pos.to_vec()
|
||||
}
|
||||
|
||||
/// Get observation dimensions.
|
||||
#[staticmethod]
|
||||
fn observation_sizes() -> (usize, usize) {
|
||||
(POSITION_HISTORY * 5, DEPTH_PIXELS)
|
||||
}
|
||||
|
||||
/// Get action size.
|
||||
#[staticmethod]
|
||||
fn action_size() -> usize {
|
||||
7
|
||||
}
|
||||
}
|
||||
|
||||
impl StrafeEnv {
|
||||
fn do_reset(&mut self) {
|
||||
self.simulation = PhysicsState::default();
|
||||
self.time = PhysicsTime::ZERO;
|
||||
self.mouse_pos = glam::ivec2(-5300, 0);
|
||||
self.position_history.clear();
|
||||
|
||||
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Reset));
|
||||
self.run_instruction(PhysicsInputInstruction::Mode(ModeInstruction::Restart(
|
||||
strafesnet_common::gameplay_modes::ModeId::MAIN,
|
||||
)));
|
||||
}
|
||||
|
||||
fn run_instruction(&mut self, instruction: PhysicsInputInstruction) {
|
||||
PhysicsContext::run_input_instruction(
|
||||
&mut self.simulation,
|
||||
&self.geometry,
|
||||
TimedInstruction {
|
||||
time: self.time,
|
||||
instruction,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn get_observation(&mut self) -> (Vec<f32>, Vec<f32>) {
|
||||
let trajectory = self.simulation.camera_trajectory(&self.geometry);
|
||||
let pos: glam::Vec3 = trajectory
|
||||
.extrapolated_position(self.time)
|
||||
.map(Into::<f32>::into)
|
||||
.to_array()
|
||||
.into();
|
||||
let camera = self.simulation.camera();
|
||||
let angles = camera.simulate_move_angles(glam::IVec2::ZERO);
|
||||
|
||||
// build position history input
|
||||
let mut pos_hist = Vec::with_capacity(POSITION_HISTORY * 5);
|
||||
if !self.position_history.is_empty() {
|
||||
let cam_matrix =
|
||||
strafesnet_graphics::graphics::view_inv(pos - self.start_offset, angles).inverse();
|
||||
for &(p, a) in self.position_history.iter().rev() {
|
||||
let relative_pos = cam_matrix.transform_vector3(p);
|
||||
let relative_ang = glam::vec2(angles.x - a.x, a.y);
|
||||
pos_hist.extend_from_slice(&relative_pos.to_array());
|
||||
pos_hist.extend_from_slice(&relative_ang.to_array());
|
||||
}
|
||||
}
|
||||
// pad remaining history with zeros
|
||||
for _ in self.position_history.len()..POSITION_HISTORY {
|
||||
pos_hist.extend_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
// update position history
|
||||
if self.position_history.len() < POSITION_HISTORY {
|
||||
self.position_history.push((pos, angles));
|
||||
} else {
|
||||
self.position_history.rotate_left(1);
|
||||
*self.position_history.last_mut().unwrap() = (pos, angles);
|
||||
}
|
||||
|
||||
// render depth
|
||||
let render_pos = pos - self.start_offset;
|
||||
let depth = self.render_depth(render_pos, angles);
|
||||
|
||||
(pos_hist, depth)
|
||||
}
|
||||
|
||||
fn render_depth(&mut self, pos: glam::Vec3, angles: glam::Vec2) -> Vec<f32> {
|
||||
let mut encoder = self
|
||||
.wgpu_device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("depth encoder"),
|
||||
});
|
||||
|
||||
self.graphics
|
||||
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
|
||||
|
||||
encoder.copy_texture_to_buffer(
|
||||
wgpu::TexelCopyTextureInfo {
|
||||
texture: self.graphics.depth_texture(),
|
||||
mip_level: 0,
|
||||
origin: wgpu::Origin3d::ZERO,
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
},
|
||||
wgpu::TexelCopyBufferInfo {
|
||||
buffer: &self.output_staging_buffer,
|
||||
layout: wgpu::TexelCopyBufferLayout {
|
||||
offset: 0,
|
||||
bytes_per_row: Some(STRIDE_SIZE),
|
||||
rows_per_image: Some(SIZE_Y),
|
||||
},
|
||||
},
|
||||
wgpu::Extent3d {
|
||||
width: SIZE_X,
|
||||
height: SIZE_Y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
);
|
||||
|
||||
self.wgpu_queue.submit([encoder.finish()]);
|
||||
|
||||
let buffer_slice = self.output_staging_buffer.slice(..);
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
|
||||
self.wgpu_device
|
||||
.poll(wgpu::PollType::wait_indefinitely())
|
||||
.unwrap();
|
||||
receiver.recv().unwrap().unwrap();
|
||||
|
||||
let mut depth = Vec::with_capacity(DEPTH_PIXELS);
|
||||
{
|
||||
let view = buffer_slice.get_mapped_range();
|
||||
self.texture_data.extend_from_slice(&view[..]);
|
||||
}
|
||||
self.output_staging_buffer.unmap();
|
||||
|
||||
for y in 0..SIZE_Y {
|
||||
depth.extend(
|
||||
self.texture_data[(STRIDE_SIZE * y) as usize
|
||||
..(STRIDE_SIZE * y + SIZE_X * size_of::<f32>() as u32) as usize]
|
||||
.chunks_exact(4)
|
||||
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
|
||||
);
|
||||
}
|
||||
|
||||
self.texture_data.clear();
|
||||
depth
|
||||
}
|
||||
}
|
||||
|
||||
/// Python module definition
|
||||
#[pymodule]
|
||||
fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<StrafeEnv>()?;
|
||||
Ok(())
|
||||
}
|
||||
423
src/main.rs
423
src/main.rs
@@ -1,423 +0,0 @@
|
||||
use burn::backend::Autodiff;
|
||||
use burn::nn::loss::{MseLoss, Reduction};
|
||||
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
|
||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||
use burn::prelude::*;
|
||||
|
||||
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||
|
||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
||||
use strafesnet_graphics::setup;
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
const SIZE: glam::UVec2 = glam::uvec2(64, 36);
|
||||
const POSITION_HISTORY: usize = 4;
|
||||
const INPUT: usize = (SIZE.x * SIZE.y) as usize + POSITION_HISTORY * 3;
|
||||
const HIDDEN: [usize; 2] = [INPUT >> 3, INPUT >> 7];
|
||||
// MoveForward
|
||||
// MoveLeft
|
||||
// MoveBack
|
||||
// MoveRight
|
||||
// Jump
|
||||
// mouse_dx
|
||||
// mouse_dy
|
||||
const OUTPUT: usize = 7;
|
||||
|
||||
// bytes_per_row needs to be a multiple of 256.
|
||||
const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256);
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Net<B: Backend> {
|
||||
input: Linear<B>,
|
||||
dropout: Dropout,
|
||||
hidden: [Linear<B>; HIDDEN.len() - 1],
|
||||
output: Linear<B>,
|
||||
activation: Relu,
|
||||
}
|
||||
impl<B: Backend> Net<B> {
|
||||
fn init(device: &B::Device) -> Self {
|
||||
let mut it = HIDDEN.into_iter();
|
||||
let mut last_size = it.next().unwrap();
|
||||
let input = LinearConfig::new(INPUT, last_size).init(device);
|
||||
let hidden = core::array::from_fn(|_| {
|
||||
let size = it.next().unwrap();
|
||||
let layer = LinearConfig::new(last_size, size).init(device);
|
||||
last_size = size;
|
||||
layer
|
||||
});
|
||||
let output = LinearConfig::new(last_size, OUTPUT).init(device);
|
||||
let dropout = DropoutConfig::new(0.1).init();
|
||||
Self {
|
||||
input,
|
||||
dropout,
|
||||
hidden,
|
||||
output,
|
||||
activation: Relu::new(),
|
||||
}
|
||||
}
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = self.input.forward(input);
|
||||
let x = self.dropout.forward(x);
|
||||
let mut x = self.activation.forward(x);
|
||||
for layer in &self.hidden {
|
||||
x = layer.forward(x);
|
||||
x = self.activation.forward(x);
|
||||
}
|
||||
self.output.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct GraphicsState {
|
||||
device: wgpu::Device,
|
||||
queue: wgpu::Queue,
|
||||
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
|
||||
graphics_texture_view: wgpu::TextureView,
|
||||
output_staging_buffer: wgpu::Buffer,
|
||||
texture_data: Vec<u8>,
|
||||
position_history: Vec<glam::Vec3>,
|
||||
}
|
||||
impl GraphicsState {
|
||||
fn new(map: &strafesnet_common::map::CompleteMap) -> Self {
|
||||
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
||||
let instance = wgpu::Instance::new(desc);
|
||||
let (device, queue) = pollster::block_on(async {
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: false,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
setup::step4::request_device(&adapter, LIMITS)
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
||||
&device, &queue, SIZE, FORMAT, LIMITS,
|
||||
);
|
||||
graphics.change_map(&device, &queue, map).unwrap();
|
||||
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
|
||||
label: Some("RGB texture"),
|
||||
format: FORMAT,
|
||||
size: wgpu::Extent3d {
|
||||
width: SIZE.x,
|
||||
height: SIZE.y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
mip_level_count: 1,
|
||||
sample_count: 1,
|
||||
dimension: wgpu::TextureDimension::D2,
|
||||
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
view_formats: &[],
|
||||
});
|
||||
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
|
||||
label: Some("RGB texture view"),
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
usage: Some(
|
||||
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE.y) as usize);
|
||||
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Output staging buffer"),
|
||||
size: texture_data.capacity() as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
let position_history = Vec::with_capacity(POSITION_HISTORY);
|
||||
Self {
|
||||
device,
|
||||
queue,
|
||||
graphics,
|
||||
graphics_texture_view,
|
||||
output_staging_buffer,
|
||||
texture_data,
|
||||
position_history,
|
||||
}
|
||||
}
|
||||
fn generate_inputs(&mut self, pos: glam::Vec3, angles: glam::Vec2, inputs: &mut Vec<f32>) {
|
||||
// write position history to model inputs
|
||||
if !self.position_history.is_empty() {
|
||||
let camera = strafesnet_graphics::graphics::view_inv(pos, angles).inverse();
|
||||
for &pos in self.position_history.iter().rev() {
|
||||
let relative_pos = camera.transform_vector3(pos);
|
||||
inputs.extend_from_slice(&relative_pos.to_array());
|
||||
}
|
||||
}
|
||||
// fill remaining history with zeroes
|
||||
for _ in self.position_history.len()..POSITION_HISTORY {
|
||||
inputs.extend_from_slice(&[0.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
// track position history
|
||||
if self.position_history.len() < POSITION_HISTORY {
|
||||
self.position_history.push(pos);
|
||||
} else {
|
||||
self.position_history.rotate_left(1);
|
||||
*self.position_history.last_mut().unwrap() = pos;
|
||||
}
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("wgpu encoder"),
|
||||
});
|
||||
|
||||
// render!
|
||||
self.graphics
|
||||
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
|
||||
|
||||
// copy the depth texture into ram
|
||||
encoder.copy_texture_to_buffer(
|
||||
wgpu::TexelCopyTextureInfo {
|
||||
texture: self.graphics.depth_texture(),
|
||||
mip_level: 0,
|
||||
origin: wgpu::Origin3d::ZERO,
|
||||
aspect: wgpu::TextureAspect::All,
|
||||
},
|
||||
wgpu::TexelCopyBufferInfo {
|
||||
buffer: &self.output_staging_buffer,
|
||||
layout: wgpu::TexelCopyBufferLayout {
|
||||
offset: 0,
|
||||
// This needs to be a multiple of 256.
|
||||
bytes_per_row: Some(STRIDE_SIZE),
|
||||
rows_per_image: Some(SIZE.y),
|
||||
},
|
||||
},
|
||||
wgpu::Extent3d {
|
||||
width: SIZE.x,
|
||||
height: SIZE.y,
|
||||
depth_or_array_layers: 1,
|
||||
},
|
||||
);
|
||||
|
||||
self.queue.submit([encoder.finish()]);
|
||||
|
||||
// map buffer
|
||||
let buffer_slice = self.output_staging_buffer.slice(..);
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
|
||||
self.device
|
||||
.poll(wgpu::PollType::wait_indefinitely())
|
||||
.unwrap();
|
||||
receiver.recv().unwrap().unwrap();
|
||||
|
||||
// copy texture inside a scope so the mapped view gets dropped
|
||||
{
|
||||
let view = buffer_slice.get_mapped_range();
|
||||
self.texture_data.extend_from_slice(&view[..]);
|
||||
}
|
||||
self.output_staging_buffer.unmap();
|
||||
|
||||
// discombolulate stride
|
||||
for y in 0..SIZE.y {
|
||||
inputs.extend(
|
||||
self.texture_data[(STRIDE_SIZE * y) as usize
|
||||
..(STRIDE_SIZE * y + SIZE.x * size_of::<f32>() as u32) as usize]
|
||||
.chunks_exact(4)
|
||||
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
|
||||
)
|
||||
}
|
||||
|
||||
self.texture_data.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn training() {
|
||||
let gpu_id: usize = std::env::args()
|
||||
.skip(1)
|
||||
.next()
|
||||
.map(|id| id.parse().unwrap())
|
||||
.unwrap_or_default();
|
||||
// load map
|
||||
// load replay
|
||||
// setup player
|
||||
|
||||
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
||||
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
||||
|
||||
// read files
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.unwrap()
|
||||
.into_complete_map()
|
||||
.unwrap();
|
||||
let timelines =
|
||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
||||
let world_offset = bot.world_offset();
|
||||
let timelines = bot.timelines();
|
||||
|
||||
// setup simulation
|
||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||
|
||||
// set up graphics
|
||||
let mut g = GraphicsState::new(&map);
|
||||
|
||||
// training data
|
||||
let training_samples = timelines.input_events.len() - 1;
|
||||
|
||||
let input_size = INPUT * size_of::<f32>();
|
||||
let mut inputs = Vec::with_capacity(input_size * training_samples);
|
||||
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
|
||||
|
||||
// generate all frames
|
||||
println!("Generating {training_samples} frames of depth textures...");
|
||||
let mut it = timelines.input_events.iter();
|
||||
|
||||
// grab mouse position from first frame, omitting one frame from the training data
|
||||
let first = it.next().unwrap();
|
||||
let mut last_mx = first.event.mouse_pos.x;
|
||||
let mut last_my = first.event.mouse_pos.y;
|
||||
|
||||
for input_event in it {
|
||||
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
|
||||
let mouse_dy = input_event.event.mouse_pos.y - last_my;
|
||||
last_mx = input_event.event.mouse_pos.x;
|
||||
last_my = input_event.event.mouse_pos.y;
|
||||
|
||||
// set targets
|
||||
targets.extend([
|
||||
// MoveForward
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveForward) as i32 as f32,
|
||||
// MoveLeft
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveLeft) as i32 as f32,
|
||||
// MoveBack
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveBack) as i32 as f32,
|
||||
// MoveRight
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveRight) as i32 as f32,
|
||||
// Jump
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::Jump) as i32 as f32,
|
||||
mouse_dx,
|
||||
mouse_dy,
|
||||
]);
|
||||
|
||||
// find the closest output event to the input event time
|
||||
let output_event_index = timelines
|
||||
.output_events
|
||||
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
|
||||
|
||||
let output_event = match output_event_index {
|
||||
// found the exact same timestamp
|
||||
Ok(output_event_index) => &timelines.output_events[output_event_index],
|
||||
// found first index greater than the time.
|
||||
// check this index and the one before and return the closest one
|
||||
Err(insert_index) => timelines
|
||||
.output_events
|
||||
.get(insert_index)
|
||||
.into_iter()
|
||||
.chain(
|
||||
insert_index
|
||||
.checked_sub(1)
|
||||
.and_then(|index| timelines.output_events.get(index)),
|
||||
)
|
||||
.min_by(|&e0, &e1| {
|
||||
(e0.time - input_event.time)
|
||||
.abs()
|
||||
.partial_cmp(&(e1.time - input_event.time).abs())
|
||||
.unwrap()
|
||||
})
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
fn vec3(v: v0::Vector3) -> glam::Vec3 {
|
||||
glam::vec3(v.x, v.y, v.z)
|
||||
}
|
||||
fn angles(a: v0::Vector3) -> glam::Vec2 {
|
||||
glam::vec2(a.y, a.x)
|
||||
}
|
||||
|
||||
let pos = vec3(output_event.event.position) - world_offset;
|
||||
let angles = angles(output_event.event.angles);
|
||||
|
||||
g.generate_inputs(pos, angles, &mut inputs);
|
||||
}
|
||||
|
||||
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
||||
|
||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||
println!("Training model ({} parameters)", model.num_params());
|
||||
|
||||
let mut optim = AdamConfig::new().init();
|
||||
|
||||
let inputs = Tensor::from_data(
|
||||
TensorData::new(inputs, Shape::new([training_samples, INPUT])),
|
||||
&device,
|
||||
);
|
||||
let targets = Tensor::from_data(
|
||||
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
|
||||
&device,
|
||||
);
|
||||
|
||||
const LEARNING_RATE: f64 = 0.001;
|
||||
const EPOCHS: usize = 100000;
|
||||
|
||||
let mut best_model = model.clone();
|
||||
let mut best_loss = f32::INFINITY;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
let predictions = model.forward(inputs.clone());
|
||||
|
||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||
|
||||
let loss_scalar = loss.clone().into_scalar();
|
||||
|
||||
if epoch == 0 {
|
||||
// kinda a fake print, but that's what is happening after this point
|
||||
println!("Compiling optimized GPU kernels...");
|
||||
}
|
||||
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
|
||||
// get the best model
|
||||
if loss_scalar < best_loss {
|
||||
best_loss = loss_scalar;
|
||||
best_model = model.clone();
|
||||
}
|
||||
|
||||
model = optim.step(LEARNING_RATE, model, grads);
|
||||
|
||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
|
||||
}
|
||||
}
|
||||
|
||||
let date_string = format!("{}_{}.model", chrono::Utc::now(), best_loss);
|
||||
best_model
|
||||
.save_file(
|
||||
date_string,
|
||||
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Run a trained model against a live simulation — currently an empty stub.
///
/// The comments below sketch the planned pipeline, in order.
fn inference() {
    // load map
    // setup simulation
    // setup agent-simulation feedback loop
    // go!
}
|
||||
|
||||
/// Entry point: only the imitation-training path is wired up;
/// `inference` above is still a stub.
fn main() {
    training();
}
|
||||
4
strafe_ai/__init__.py
Normal file
4
strafe_ai/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from strafe_ai.environment import StrafeEnvironment
|
||||
from strafe_ai.model import StrafeNet
|
||||
|
||||
__all__ = ["StrafeEnvironment", "StrafeNet"]
|
||||
71
strafe_ai/environment.py
Normal file
71
strafe_ai/environment.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Gymnasium-compatible environment wrapping the Rust physics sim + depth renderer.
|
||||
|
||||
Usage:
|
||||
env = StrafeEnvironment("path/to/map.snfm")
|
||||
obs, info = env.reset()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
"""
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from strafe_ai._rust import StrafeEnv
|
||||
|
||||
|
||||
class StrafeEnvironment(gym.Env):
    """
    A bhop environment.

    Observation: dict with "position_history" and "depth" arrays.
    Action: 7 floats — [forward, left, back, right, jump, mouse_dx, mouse_dy]
    First 5 are binary (thresholded at 0.5), last 2 are continuous.
    """

    metadata = {"render_modes": ["none"]}

    def __init__(self, map_path: str):
        super().__init__()

        # The Rust extension owns the physics simulation and depth renderer.
        self._env = StrafeEnv(map_path)
        pos_size, depth_size = StrafeEnv.observation_sizes()

        # observation: position history + depth buffer
        self.observation_space = spaces.Dict({
            "position_history": spaces.Box(-np.inf, np.inf, shape=(pos_size,), dtype=np.float32),
            "depth": spaces.Box(-1.0, 1.0, shape=(depth_size,), dtype=np.float32),
        })

        # action: 5 binary controls + 2 continuous mouse deltas
        lows = np.array([0, 0, 0, 0, 0, -100, -100], dtype=np.float32)
        highs = np.array([1, 1, 1, 1, 1, 100, 100], dtype=np.float32)
        self.action_space = spaces.Box(low=lows, high=highs)

    @staticmethod
    def _build_obs(pos_hist, depth):
        """Pack the raw observation sequences into the dict observation."""
        return {
            "position_history": np.array(pos_hist, dtype=np.float32),
            "depth": np.array(depth, dtype=np.float32),
        }

    def reset(self, seed=None, options=None):
        """Reset the simulation; returns (observation, info)."""
        super().reset(seed=seed)
        pos_hist, depth = self._env.reset()
        return self._build_obs(pos_hist, depth), {}

    def step(self, action):
        """Advance one tick; returns (obs, reward, terminated, truncated, info)."""
        if hasattr(action, "tolist"):
            raw_action = action.tolist()
        else:
            raw_action = list(action)
        pos_hist, depth, done = self._env.step(raw_action)

        obs = self._build_obs(pos_hist, depth)

        # TODO: implement Rhys's reward function (curve_dt progress along WR path)
        reward = 0.0

        return obs, reward, done, False, {}

    def get_position(self):
        """Get the agent's current [x, y, z] position."""
        return np.array(self._env.get_position(), dtype=np.float32)
|
||||
57
strafe_ai/model.py
Normal file
57
strafe_ai/model.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Neural network model — PyTorch equivalent of the Rust Net struct.
|
||||
|
||||
Takes position history + depth buffer, outputs 7 control values.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from strafe_ai._rust import StrafeEnv
|
||||
|
||||
# Observation/action dimensions come from the Rust extension so the Python
# model always agrees with the simulator's buffer sizes.
POS_SIZE, DEPTH_SIZE = StrafeEnv.observation_sizes()
ACTION_SIZE = StrafeEnv.action_size()

# hidden layer sizes (same ratios as the Rust version)
INPUT_SIZE = POS_SIZE + DEPTH_SIZE
# Widths are INPUT_SIZE/8, /32, /128 (right-shifts), so the net narrows fast.
HIDDEN = [INPUT_SIZE >> 3, INPUT_SIZE >> 5, INPUT_SIZE >> 7]
|
||||
|
||||
|
||||
class StrafeNet(nn.Module):
    """
    Simple feedforward network for bhop control.

    Architecture matches the Rust version:
    - Dropout on depth input
    - Concatenate position history + depth
    - 3 hidden layers with ReLU
    - Linear output (7 values)
    """

    def __init__(self, dropout_rate: float = 0.1):
        super().__init__()

        self.depth_dropout = nn.Dropout(dropout_rate)

        # Build Linear+ReLU pairs from consecutive width pairs, then a bare
        # linear head producing the 7 control values.
        widths = [INPUT_SIZE, *HIDDEN]
        stack = []
        for in_features, out_features in zip(widths, widths[1:]):
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], ACTION_SIZE))

        self.network = nn.Sequential(*stack)

    def forward(self, position_history: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        """
        Args:
            position_history: (batch, POS_SIZE) — relative position + angle history
            depth: (batch, DEPTH_SIZE) — depth buffer pixels

        Returns:
            (batch, 7) — [forward, left, back, right, jump, mouse_dx, mouse_dy]
        """
        dropped_depth = self.depth_dropout(depth)
        combined = torch.cat([position_history, dropped_depth], dim=1)
        return self.network(combined)
|
||||
99
strafe_ai/train.py
Normal file
99
strafe_ai/train.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Training script — imitation learning from .qbot replay files.
|
||||
|
||||
This is the Python equivalent of the Rust training code.
|
||||
It trains the model to predict player controls from depth frames.
|
||||
|
||||
Usage:
|
||||
python -m strafe_ai.train --map-file map.snfm --bot-file replay.qbot --epochs 10000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from strafe_ai.model import StrafeNet, POS_SIZE, DEPTH_SIZE, ACTION_SIZE
|
||||
from strafe_ai.environment import StrafeEnvironment
|
||||
|
||||
|
||||
def load_training_data(env: StrafeEnvironment, bot_file: str):
    """
    Generate training data by replaying a .qbot file through the environment.

    For now this is a placeholder — the actual .qbot parsing happens in Rust.
    You would call the Rust training data generator and load the results.
    """
    # TODO: add a Rust function to generate training pairs from .qbot files
    # For now, this returns dummy data to verify the pipeline works
    message = (
        "Training data generation from .qbot files requires Rust bindings. "
        "Use the Rust training code for imitation learning, or implement RL below."
    )
    raise NotImplementedError(message)
|
||||
|
||||
|
||||
def train_rl(map_file: str, epochs: int, lr: float, device: str):
    """
    Reinforcement learning training loop.

    This is the main training path going forward.

    Args:
        map_file: Path to the .snfm map the environment should load.
        epochs: Currently unused — reserved for the real RL loop (TODO).
        lr: Currently unused — reserved for the real RL loop (TODO).
        device: Torch device string, e.g. "cuda" or "cpu".
    """
    print(f"Setting up environment with map: {map_file}")
    env = StrafeEnvironment(map_file)

    print(f"Using device: {device}")
    dev = torch.device(device)

    model = StrafeNet().to(dev)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model has {param_count:,} parameters")

    writer = SummaryWriter("runs/strafe-ai")

    # TODO: implement RL algorithm (PPO via Stable Baselines3, or custom)
    # For now, run random actions to verify the environment works
    print("Running environment test (random actions)...")
    obs, info = env.reset()
    total_reward = 0.0

    for step in range(100):
        # random action
        action = env.action_space.sample()

        # forward pass through model (just to verify it works)
        pos_hist = torch.tensor(obs["position_history"], device=dev).unsqueeze(0)
        depth = torch.tensor(obs["depth"], device=dev).unsqueeze(0)
        with torch.no_grad():
            _ = model(pos_hist, depth)

        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward

        # restart the episode whenever it ends, either way
        if terminated or truncated:
            obs, info = env.reset()

        if step % 10 == 0:
            pos = env.get_position()
            print(f" step {step:4d} | pos = ({pos[0]:.1f}, {pos[1]:.1f}, {pos[2]:.1f})")

    print(f"Test complete. Total reward: {total_reward}")
    writer.close()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and launch RL training."""
    ap = argparse.ArgumentParser(description="Train strafe-ai")
    ap.add_argument("--map-file", required=True, help="Path to .snfm map file")
    ap.add_argument("--bot-file", help="Path to .qbot file (for imitation learning)")
    ap.add_argument("--epochs", type=int, default=10000)
    ap.add_argument("--lr", type=float, default=0.001)
    # Default to the GPU when torch can see one.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    ap.add_argument("--device", default=default_device)
    opts = ap.parse_args()

    # NOTE(review): --bot-file is parsed but not yet consumed; imitation
    # learning still goes through the Rust path (see load_training_data).
    train_rl(opts.map_file, opts.epochs, opts.lr, opts.device)
|
||||
Reference in New Issue
Block a user