forked from StrafesNET/strafe-ai
write things
This commit is contained in:
101
src/main.rs
101
src/main.rs
@@ -9,8 +9,8 @@ type InferenceBackend = burn::backend::Cuda<f32>;
|
||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||
|
||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
use strafesnet_common::session::Time as SessionTime;
|
||||
use strafesnet_graphics::setup;
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
const INPUT: usize = 2;
|
||||
const HIDDEN: usize = 64;
|
||||
@@ -98,7 +98,7 @@ fn training() {
|
||||
// setup simulation
|
||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||
|
||||
// generate all frames
|
||||
// set up textures
|
||||
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
|
||||
label: Some("RGB texture"),
|
||||
format: FORMAT,
|
||||
@@ -127,8 +127,57 @@ fn training() {
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
// let mut inputs = Vec::new();
|
||||
for input_event in &timelines.input_events {
|
||||
|
||||
let mut inputs = Vec::with_capacity(
|
||||
(size.x * size.y) as usize * size_of::<f32>() * timelines.input_events.len(),
|
||||
);
|
||||
let mut targets = Vec::with_capacity(OUTPUT * timelines.input_events.len());
|
||||
|
||||
// generate all frames
|
||||
let mut it = timelines.input_events.iter();
|
||||
|
||||
// grab mouse position from first frame, omitting one frame from the training data
|
||||
let first = it.next().unwrap();
|
||||
let mut last_mx = first.event.mouse_pos.x;
|
||||
let mut last_my = first.event.mouse_pos.y;
|
||||
|
||||
for input_event in it {
|
||||
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
|
||||
let mouse_dy = input_event.event.mouse_pos.y - last_my;
|
||||
last_mx = input_event.event.mouse_pos.x;
|
||||
last_my = input_event.event.mouse_pos.y;
|
||||
|
||||
// set targets
|
||||
targets.extend([
|
||||
// MoveForward
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveForward) as i32 as f32,
|
||||
// MoveLeft
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveLeft) as i32 as f32,
|
||||
// MoveBack
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveBack) as i32 as f32,
|
||||
// MoveRight
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::MoveRight) as i32 as f32,
|
||||
// Jump
|
||||
input_event
|
||||
.event
|
||||
.game_controls
|
||||
.contains(v0::GameControls::Jump) as i32 as f32,
|
||||
mouse_dx,
|
||||
mouse_dy,
|
||||
]);
|
||||
|
||||
let output_event_index = timelines
|
||||
.output_events
|
||||
.binary_search_by(|event| event.time.total_cmp(&input_event.time));
|
||||
@@ -155,10 +204,10 @@ fn training() {
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
fn p(v: strafesnet_roblox_bot_file::v0::Vector3) -> [f32; 3] {
|
||||
fn p(v: v0::Vector3) -> [f32; 3] {
|
||||
[v.x, v.y, v.z]
|
||||
}
|
||||
fn a(a: strafesnet_roblox_bot_file::v0::Vector3) -> [f32; 2] {
|
||||
fn a(a: v0::Vector3) -> [f32; 2] {
|
||||
[a.x, a.y]
|
||||
}
|
||||
|
||||
@@ -200,16 +249,52 @@ fn training() {
|
||||
);
|
||||
|
||||
queue.submit([encoder.finish()]);
|
||||
|
||||
// map buffer
|
||||
let buffer_slice = output_staging_buffer.slice(..);
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
|
||||
device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
|
||||
receiver.recv().unwrap().unwrap();
|
||||
|
||||
// copy texture
|
||||
let view = buffer_slice.get_mapped_range();
|
||||
texture_data.extend_from_slice(view.iter().as_slice());
|
||||
|
||||
// discombolulate stride
|
||||
for y in 0..size.y {
|
||||
inputs.extend_from_slice(
|
||||
&texture_data[(stride * y) as usize..(stride * y + size.x) as usize],
|
||||
);
|
||||
}
|
||||
|
||||
texture_data.clear();
|
||||
}
|
||||
|
||||
// finished with this buffer
|
||||
output_staging_buffer.unmap();
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||
|
||||
let mut optim = SgdConfig::new().init();
|
||||
|
||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
||||
let inputs = Tensor::from_data(
|
||||
TensorData::new(
|
||||
inputs,
|
||||
Shape::new([
|
||||
size.x as usize,
|
||||
size.y as usize,
|
||||
timelines.input_events.len(),
|
||||
]),
|
||||
),
|
||||
&device,
|
||||
);
|
||||
let targets = Tensor::from_data(
|
||||
TensorData::new(targets, Shape::new([OUTPUT, timelines.input_events.len()])),
|
||||
&device,
|
||||
);
|
||||
|
||||
const LEARNING_RATE: f64 = 0.5;
|
||||
const EPOCHS: usize = 100;
|
||||
|
||||
Reference in New Issue
Block a user