write things

This commit is contained in:
2026-03-26 10:11:18 -07:00
parent 48eefba747
commit 3ce6ad84a3

View File

@@ -9,8 +9,8 @@ type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_common::session::Time as SessionTime;
use strafesnet_graphics::setup;
use strafesnet_roblox_bot_file::v0;
const INPUT: usize = 2;
const HIDDEN: usize = 64;
@@ -98,7 +98,7 @@ fn training() {
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
// generate all frames
// set up textures
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
@@ -127,8 +127,57 @@ fn training() {
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
// let mut inputs = Vec::new();
for input_event in &timelines.input_events {
let mut inputs = Vec::with_capacity(
(size.x * size.y) as usize * size_of::<f32>() * timelines.input_events.len(),
);
let mut targets = Vec::with_capacity(OUTPUT * timelines.input_events.len());
// generate all frames
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.total_cmp(&input_event.time));
@@ -155,10 +204,10 @@ fn training() {
.unwrap(),
};
fn p(v: strafesnet_roblox_bot_file::v0::Vector3) -> [f32; 3] {
fn p(v: v0::Vector3) -> [f32; 3] {
[v.x, v.y, v.z]
}
fn a(a: strafesnet_roblox_bot_file::v0::Vector3) -> [f32; 2] {
fn a(a: v0::Vector3) -> [f32; 2] {
[a.x, a.y]
}
@@ -200,16 +249,52 @@ fn training() {
);
queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
receiver.recv().unwrap().unwrap();
// copy texture
let view = buffer_slice.get_mapped_range();
texture_data.extend_from_slice(view.iter().as_slice());
// discombolulate stride
for y in 0..size.y {
inputs.extend_from_slice(
&texture_data[(stride * y) as usize..(stride * y + size.x) as usize],
);
}
texture_data.clear();
}
// finished with this buffer
output_staging_buffer.unmap();
let device = Default::default();
let mut model: Net<TrainingBackend> = Net::init(&device);
let mut optim = SgdConfig::new().init();
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
let inputs = Tensor::from_data(
TensorData::new(
inputs,
Shape::new([
size.x as usize,
size.y as usize,
timelines.input_events.len(),
]),
),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([OUTPUT, timelines.input_events.len()])),
&device,
);
const LEARNING_RATE: f64 = 0.5;
const EPOCHS: usize = 100;