implement training

This commit is contained in:
2026-03-27 11:13:23 -07:00
parent c04e8d4f3b
commit 59bb8eee12
4 changed files with 268 additions and 48 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
/files
/target /target

9
Cargo.lock generated
View File

@@ -5451,6 +5451,7 @@ name = "strafe-ai"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"burn", "burn",
"glam",
"pollster", "pollster",
"strafesnet_common", "strafesnet_common",
"strafesnet_graphics", "strafesnet_graphics",
@@ -5478,9 +5479,9 @@ dependencies = [
[[package]] [[package]]
name = "strafesnet_graphics" name = "strafesnet_graphics"
version = "0.0.10" version = "0.0.11-depth2"
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/" source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
checksum = "5080cb31a6cf898daab6c960801828ce9834dba8e932dea6b02823651ea53c33" checksum = "829804ab9c167365e576de8ebd8a245ad979cb24558b086e693e840697d7956c"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"ddsfile", "ddsfile",
@@ -5515,9 +5516,9 @@ dependencies = [
[[package]] [[package]]
name = "strafesnet_roblox_bot_player" name = "strafesnet_roblox_bot_player"
version = "0.6.1" version = "0.6.2-depth2"
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/" source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
checksum = "0669779b58836ac36b0166f5a3f326ee46ce25b4d14b7fd6f75bf273e806c1bf" checksum = "f39e7dfc0cb23e482089dc7eac235ad4b274ccfdb8df7617889a90e64a1e247a"
dependencies = [ dependencies = [
"glam", "glam",
"strafesnet_common", "strafesnet_common",

View File

@@ -5,12 +5,13 @@ edition = "2024"
[dependencies] [dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] } burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0" wgpu = "29.0.0"
strafesnet_common = { version = "0.9.0", registry = "strafesnet" } strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" } strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" } strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" } strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" } strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" } strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
pollster = "0.4.0"

View File

@@ -1,19 +1,20 @@
use burn::backend::Autodiff; use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::nn::loss::{MseLoss, Reduction}; use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid}; use burn::nn::{Linear, LinearConfig, Relu};
use burn::optim::{GradientsParams, Optimizer, SgdConfig}; use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*; use burn::prelude::*;
type InferenceBackend = burn::backend::Cuda<f32>; type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>; type TrainingBackend = Autodiff<InferenceBackend>;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults(); const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_common::session::Time as SessionTime;
use strafesnet_graphics::setup; use strafesnet_graphics::setup;
use strafesnet_roblox_bot_file::v0;
const INPUT: usize = 2; const SIZE_X: usize = 64;
const HIDDEN: usize = 64; const SIZE_Y: usize = 36;
const INPUT: usize = SIZE_X * SIZE_Y;
const HIDDEN: [usize; 2] = [INPUT >> 3, INPUT >> 7];
// MoveForward // MoveForward
// MoveLeft // MoveLeft
// MoveBack // MoveBack
@@ -26,19 +27,27 @@ const OUTPUT: usize = 7;
#[derive(Module, Debug)] #[derive(Module, Debug)]
struct Net<B: Backend> { struct Net<B: Backend> {
input: Linear<B>, input: Linear<B>,
hidden: [Linear<B>; 64], hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>, output: Linear<B>,
activation: Relu, activation: Relu,
sigmoid: Sigmoid,
} }
impl<B: Backend> Net<B> { impl<B: Backend> Net<B> {
fn init(device: &B::Device) -> Self { fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
Self { Self {
input: LinearConfig::new(INPUT, HIDDEN).init(device), input,
hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)), hidden,
output: LinearConfig::new(HIDDEN, OUTPUT).init(device), output,
activation: Relu::new(), activation: Relu::new(),
sigmoid: Sigmoid::new(),
} }
} }
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> { fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
@@ -48,20 +57,22 @@ impl<B: Backend> Net<B> {
x = layer.forward(x); x = layer.forward(x);
x = self.activation.forward(x); x = self.activation.forward(x);
} }
let x = self.output.forward(x); self.output.forward(x)
self.sigmoid.forward(x)
} }
} }
fn training() { fn training() {
let gpu_id: usize = std::env::args()
.skip(1)
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
// load map // load map
// load replay // load replay
// setup player // setup player
const SIZE_X: usize = 64;
const SIZE_Y: usize = 36;
let map_file = include_bytes!("../bhop_marble_5692093612.snfm"); let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot"); let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
// read files // read files
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
@@ -70,10 +81,9 @@ fn training() {
.unwrap(); .unwrap();
let timelines = let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap(); strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap(); let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let mut playback_head = let world_offset = bot.world_offset();
strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO); let timelines = bot.timelines();
// setup graphics // setup graphics
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env(); let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
@@ -93,47 +103,252 @@ fn training() {
}); });
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb; const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new( let size = [SIZE_X as u32, SIZE_Y as u32].into();
&device, let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&queue, &device, &queue, size, FORMAT, LIMITS,
[SIZE_X as u32, SIZE_Y as u32].into(),
FORMAT,
LIMITS,
); );
graphics.change_map(&device, &queue, &map).unwrap();
// setup simulation // setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map // run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
let device = Default::default(); // set up textures
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: size.x,
height: size.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING),
..Default::default()
});
// bytes_per_row needs to be a multiple of 256.
let stride_size = (size.x * size_of::<f32>() as u32).next_multiple_of(256);
let mut texture_data = Vec::<u8>::with_capacity((stride_size * size.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = INPUT * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
// Convert a bot-file `v0::Vector3` into a `glam::Vec3`, copying x, y and z in order.
fn vec3(v: v0::Vector3) -> glam::Vec3 {
glam::vec3(v.x, v.y, v.z)
}
// Build a `glam::Vec2` from an angles vector with the first two components swapped:
// the result's x is `a.y` and its y is `a.x`; `a.z` is not used.
// NOTE(review): presumably this matches the angle ordering `encode_commands` expects — confirm.
fn angles(a: v0::Vector3) -> glam::Vec2 {
glam::vec2(a.y, a.x)
}
let pos = vec3(output_event.event.position) - world_offset;
let angles = angles(output_event.event.angles);
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
graphics.encode_commands(&mut encoder, &graphics_texture_view, pos, angles);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(stride_size as u32),
rows_per_image: Some(size.y),
},
},
wgpu::Extent3d {
width: size.x,
height: size.y,
depth_or_array_layers: 1,
},
);
queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
texture_data.extend_from_slice(&view[..]);
}
output_staging_buffer.unmap();
// discombobulate stride: keep only the first size.x floats of each padded row
for y in 0..size.y {
inputs.extend(
texture_data[(stride_size * y) as usize
..(stride_size * y + size.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap())),
)
}
texture_data.clear();
}
// normalize inputs
let global_min = *inputs
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap();
let global_max = *inputs
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap();
let global_range = global_max - global_min;
println!("Normalizing to range {global_min} - {global_max}");
inputs.iter_mut().for_each(|value| {
*value = 1.0 - (*value - global_min) / global_range;
});
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device); let mut model: Net<TrainingBackend> = Net::init(&device);
println!("Training model ({} parameters)", model.num_params());
let mut optim = SgdConfig::new().init(); let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_floats([0.0f32; INPUT], &device); let inputs = Tensor::from_data(
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device); TensorData::new(inputs, Shape::new([training_samples, INPUT])),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
const LEARNING_RATE: f64 = 0.5; const LEARNING_RATE: f64 = 0.001;
const EPOCHS: usize = 100; const EPOCHS: usize = 100000;
for epoch in 0..EPOCHS { for epoch in 0..EPOCHS {
let predictions = model.forward(inputs.clone()); let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean); let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 { let loss_scalar = loss.clone().into_scalar();
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!( if epoch == 0 {
" epoch {:>5} | loss = {:.8}", // kinda a fake print, but that's what is happening after this point
epoch, println!("Compiling optimized GPU kernels...");
loss.clone().into_scalar()
);
} }
let grads = loss.backward(); let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model); let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(LEARNING_RATE, model, grads); model = optim.step(LEARNING_RATE, model, grads);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}
} }
} }
@@ -144,4 +359,6 @@ fn inference() {
// go! // go!
} }
fn main() {} fn main() {
training();
}