From 59bb8eee12b14fd5e97e3c8c56c72f4fddc78825 Mon Sep 17 00:00:00 2001 From: Rhys Lloyd Date: Fri, 27 Mar 2026 11:13:23 -0700 Subject: [PATCH] implement training --- .gitignore | 1 + Cargo.lock | 9 +- Cargo.toml | 7 +- src/main.rs | 299 +++++++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 268 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index ea8c4bf..15f9bfb 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ +/files /target diff --git a/Cargo.lock b/Cargo.lock index fe08d76..3cc2424 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5451,6 +5451,7 @@ name = "strafe-ai" version = "0.1.0" dependencies = [ "burn", + "glam", "pollster", "strafesnet_common", "strafesnet_graphics", @@ -5478,9 +5479,9 @@ dependencies = [ [[package]] name = "strafesnet_graphics" -version = "0.0.10" +version = "0.0.11-depth2" source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/" -checksum = "5080cb31a6cf898daab6c960801828ce9834dba8e932dea6b02823651ea53c33" +checksum = "829804ab9c167365e576de8ebd8a245ad979cb24558b086e693e840697d7956c" dependencies = [ "bytemuck", "ddsfile", @@ -5515,9 +5516,9 @@ dependencies = [ [[package]] name = "strafesnet_roblox_bot_player" -version = "0.6.1" +version = "0.6.2-depth2" source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/" -checksum = "0669779b58836ac36b0166f5a3f326ee46ce25b4d14b7fd6f75bf273e806c1bf" +checksum = "f39e7dfc0cb23e482089dc7eac235ad4b274ccfdb8df7617889a90e64a1e247a" dependencies = [ "glam", "strafesnet_common", diff --git a/Cargo.toml b/Cargo.toml index af9fb76..406f541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,12 +5,13 @@ edition = "2024" [dependencies] burn = { version = "0.20.1", features = ["cuda", "autodiff"] } +glam = "0.32.1" +pollster = "0.4.0" wgpu = "29.0.0" strafesnet_common = { version = "0.9.0", registry = "strafesnet" } -strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" } +strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" } strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" } strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" } -strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" } +strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" } strafesnet_snf = { version = "0.4.0", registry = "strafesnet" } -pollster = "0.4.0" diff --git a/src/main.rs b/src/main.rs index 15a1aa3..6d4bb1d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,20 @@ use burn::backend::Autodiff; -use burn::module::AutodiffModule; use burn::nn::loss::{MseLoss, Reduction}; -use burn::nn::{Linear, LinearConfig, Relu, Sigmoid}; -use burn::optim::{GradientsParams, Optimizer, SgdConfig}; +use burn::nn::{Linear, LinearConfig, Relu}; +use burn::optim::{AdamConfig, GradientsParams, Optimizer}; use burn::prelude::*; type InferenceBackend = burn::backend::Cuda; type TrainingBackend = Autodiff; const LIMITS: wgpu::Limits = wgpu::Limits::defaults(); -use strafesnet_common::session::Time as SessionTime; use strafesnet_graphics::setup; +use strafesnet_roblox_bot_file::v0; -const INPUT: usize = 2; -const HIDDEN: usize = 64; +const SIZE_X: usize = 64; +const SIZE_Y: usize = 36; +const INPUT: usize = SIZE_X * SIZE_Y; +const HIDDEN: [usize; 2] = [INPUT >> 3, INPUT >> 7]; // MoveForward // MoveLeft // MoveBack @@ -26,19 +27,27 @@ const OUTPUT: usize = 7; #[derive(Module, Debug)] struct Net { input: Linear, - hidden: [Linear; 64], + hidden: [Linear; HIDDEN.len() - 1], output: Linear, activation: Relu, - sigmoid: Sigmoid, } impl Net { fn init(device: &B::Device) -> Self { + let mut it = HIDDEN.into_iter(); + let mut last_size = it.next().unwrap(); + let input = LinearConfig::new(INPUT, last_size).init(device); + let hidden = core::array::from_fn(|_| { + let size = it.next().unwrap(); + let layer = LinearConfig::new(last_size, size).init(device); + last_size = size; + layer + }); + let output = LinearConfig::new(last_size, OUTPUT).init(device); Self { - input: LinearConfig::new(INPUT, HIDDEN).init(device), - hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)), - output: LinearConfig::new(HIDDEN, OUTPUT).init(device), + input, + hidden, + output, activation: Relu::new(), - sigmoid: Sigmoid::new(), } } fn forward(&self, input: Tensor) -> Tensor { @@ -48,20 +57,22 @@ impl Net { x = layer.forward(x); x = self.activation.forward(x); } - let x = self.output.forward(x); - self.sigmoid.forward(x) + self.output.forward(x) } } fn training() { + let gpu_id: usize = std::env::args() + .skip(1) + .next() + .map(|id| id.parse().unwrap()) + .unwrap_or_default(); // load map // load replay // setup player - const SIZE_X: usize = 64; - const SIZE_Y: usize = 36; - let map_file = include_bytes!("../bhop_marble_5692093612.snfm"); - let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot"); + let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); + let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot"); // read files let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) @@ -70,10 +81,9 @@ fn training() { .unwrap(); let timelines = strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap(); - let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap(); - let mut playback_head = - strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO); + let world_offset = bot.world_offset(); + let timelines = bot.timelines(); // setup graphics let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env(); @@ -93,47 +103,252 @@ fn training() { }); const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb; - let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new( - &device, - &queue, - [SIZE_X as u32, SIZE_Y as u32].into(), - FORMAT, - LIMITS, + let size = [SIZE_X as u32, SIZE_Y as u32].into(); + let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new( + &device, &queue, size, FORMAT, LIMITS, ); + graphics.change_map(&device, &queue, &map).unwrap(); // setup simulation // run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map - let device = Default::default(); + // set up textures + let graphics_texture = device.create_texture(&wgpu::TextureDescriptor { + label: Some("RGB texture"), + format: FORMAT, + size: wgpu::Extent3d { + width: size.x, + height: size.y, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING, + view_formats: &[], + }); + let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor { + label: Some("RGB texture view"), + aspect: wgpu::TextureAspect::All, + usage: Some(wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING), + ..Default::default() + }); + // bytes_per_row needs to be a multiple of 256. + let stride_size = (size.x * size_of::() as u32).next_multiple_of(256); + let mut texture_data = Vec::::with_capacity((stride_size * size.y) as usize); + let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Output staging buffer"), + size: texture_data.capacity() as u64, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + // training data + let training_samples = timelines.input_events.len() - 1; + + let input_size = INPUT * size_of::(); + let mut inputs = Vec::with_capacity(input_size * training_samples); + let mut targets = Vec::with_capacity(OUTPUT * training_samples); + + // generate all frames + println!("Generating {training_samples} frames of depth textures..."); + let mut it = timelines.input_events.iter(); + + // grab mouse position from first frame, omitting one frame from the training data + let first = it.next().unwrap(); + let mut last_mx = first.event.mouse_pos.x; + let mut last_my = first.event.mouse_pos.y; + + for input_event in it { + let mouse_dx = input_event.event.mouse_pos.x - last_mx; + let mouse_dy = input_event.event.mouse_pos.y - last_my; + last_mx = input_event.event.mouse_pos.x; + last_my = input_event.event.mouse_pos.y; + + // set targets + targets.extend([ + // MoveForward + input_event + .event + .game_controls + .contains(v0::GameControls::MoveForward) as i32 as f32, + // MoveLeft + input_event + .event + .game_controls + .contains(v0::GameControls::MoveLeft) as i32 as f32, + // MoveBack + input_event + .event + .game_controls + .contains(v0::GameControls::MoveBack) as i32 as f32, + // MoveRight + input_event + .event + .game_controls + .contains(v0::GameControls::MoveRight) as i32 as f32, + // Jump + input_event + .event + .game_controls + .contains(v0::GameControls::Jump) as i32 as f32, + mouse_dx, + mouse_dy, + ]); + + // find the closest output event to the input event time + let output_event_index = timelines + .output_events + .binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap()); + + let output_event = match output_event_index { + // found the exact same timestamp + Ok(output_event_index) => &timelines.output_events[output_event_index], + // found first index greater than the time. + // check this index and the one before and return the closest one + Err(insert_index) => timelines + .output_events + .get(insert_index) + .into_iter() + .chain( + insert_index + .checked_sub(1) + .and_then(|index| timelines.output_events.get(index)), + ) + .min_by(|&e0, &e1| { + (e0.time - input_event.time) + .abs() + .partial_cmp(&(e1.time - input_event.time).abs()) + .unwrap() + }) + .unwrap(), + }; + + fn vec3(v: v0::Vector3) -> glam::Vec3 { + glam::vec3(v.x, v.y, v.z) + } + fn angles(a: v0::Vector3) -> glam::Vec2 { + glam::vec2(a.y, a.x) + } + + let pos = vec3(output_event.event.position) - world_offset; + let angles = angles(output_event.event.angles); + + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("wgpu encoder"), + }); + + // render! + graphics.encode_commands(&mut encoder, &graphics_texture_view, pos, angles); + + // copy the depth texture into ram + encoder.copy_texture_to_buffer( + wgpu::TexelCopyTextureInfo { + texture: graphics.depth_texture(), + mip_level: 0, + origin: wgpu::Origin3d::ZERO, + aspect: wgpu::TextureAspect::All, + }, + wgpu::TexelCopyBufferInfo { + buffer: &output_staging_buffer, + layout: wgpu::TexelCopyBufferLayout { + offset: 0, + // This needs to be a multiple of 256. + bytes_per_row: Some(stride_size as u32), + rows_per_image: Some(size.y), + }, + }, + wgpu::Extent3d { + width: size.x, + height: size.y, + depth_or_array_layers: 1, + }, + ); + + queue.submit([encoder.finish()]); + + // map buffer + let buffer_slice = output_staging_buffer.slice(..); + let (sender, receiver) = std::sync::mpsc::channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap()); + device.poll(wgpu::PollType::wait_indefinitely()).unwrap(); + receiver.recv().unwrap().unwrap(); + + // copy texture inside a scope so the mapped view gets dropped + { + let view = buffer_slice.get_mapped_range(); + texture_data.extend_from_slice(&view[..]); + } + output_staging_buffer.unmap(); + + // discombolulate stride + for y in 0..size.y { + inputs.extend( + texture_data[(stride_size * y) as usize + ..(stride_size * y + size.x * size_of::() as u32) as usize] + .chunks_exact(4) + .map(|b| f32::from_le_bytes(b.try_into().unwrap())), + ) + } + + texture_data.clear(); + } + + // normalize inputs + let global_min = *inputs + .iter() + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let global_max = *inputs + .iter() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let global_range = global_max - global_min; + println!("Normalizing to range {global_min} - {global_max}"); + inputs.iter_mut().for_each(|value| { + *value = 1.0 - (*value - global_min) / global_range; + }); + + let device = burn::backend::cuda::CudaDevice::new(gpu_id); let mut model: Net = Net::init(&device); + println!("Training model ({} parameters)", model.num_params()); - let mut optim = SgdConfig::new().init(); + let mut optim = AdamConfig::new().init(); - let inputs = Tensor::from_floats([0.0f32; INPUT], &device); - let targets = Tensor::from_floats([0.0f32; OUTPUT], &device); + let inputs = Tensor::from_data( + TensorData::new(inputs, Shape::new([training_samples, INPUT])), + &device, + ); + let targets = Tensor::from_data( + TensorData::new(targets, Shape::new([training_samples, OUTPUT])), + &device, + ); - const LEARNING_RATE: f64 = 0.5; - const EPOCHS: usize = 100; + const LEARNING_RATE: f64 = 0.001; + const EPOCHS: usize = 100000; for epoch in 0..EPOCHS { let predictions = model.forward(inputs.clone()); let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean); - if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 { - // .clone().into_scalar() extracts the f32 value from a 1-element tensor. - println!( - " epoch {:>5} | loss = {:.8}", - epoch, - loss.clone().into_scalar() - ); + let loss_scalar = loss.clone().into_scalar(); + + if epoch == 0 { + // kinda a fake print, but that's what is happening after this point + println!("Compiling optimized GPU kernels..."); } let grads = loss.backward(); let grads = GradientsParams::from_grads(grads, &model); model = optim.step(LEARNING_RATE, model, grads); + + if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 { + // .clone().into_scalar() extracts the f32 value from a 1-element tensor. + println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar); + } } } @@ -144,4 +359,6 @@ fn inference() { // go! } -fn main() {} +fn main() { + training(); +}