diff --git a/src/main.rs b/src/main.rs index c321b31..5b3348a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -411,13 +411,201 @@ fn training() { .unwrap(); } +use strafesnet_common::instruction::TimedInstruction; +use strafesnet_common::mouse::MouseState; +use strafesnet_common::physics::{ + Instruction as PhysicsInputInstruction, MiscInstruction, ModeInstruction, MouseInstruction, + SetControlInstruction, Time as PhysicsTime, +}; +use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState}; + +pub struct Recording { + instructions: Vec>, +} +struct FrameState { + trajectory: strafesnet_physics::physics::Trajectory, + camera: strafesnet_physics::physics::PhysicsCamera, +} +impl FrameState { + fn pos(&self, time: PhysicsTime) -> glam::Vec3 { + self.trajectory + .extrapolated_position(time) + .map(Into::::into) + .to_array() + .into() + } + fn angles(&self) -> glam::Vec2 { + self.camera.simulate_move_angles(glam::IVec2::ZERO) + } +} +struct Session { + geometry_shared: PhysicsData, + simulation: PhysicsState, + recording: Recording, +} +impl Session { + fn get_frame_state(&self) -> FrameState { + FrameState { + trajectory: self.simulation.camera_trajectory(&self.geometry_shared), + camera: self.simulation.camera(), + } + } + fn run(&mut self, time: PhysicsTime, instruction: PhysicsInputInstruction) { + let instruction = TimedInstruction { time, instruction }; + self.recording.instructions.push(instruction.clone()); + PhysicsContext::run_input_instruction( + &mut self.simulation, + &self.geometry_shared, + instruction, + ); + } +} + fn inference() { + let mut args = std::env::args().skip(1); + + // pick device + let gpu_id: usize = args + .next() + .map(|id| id.parse().unwrap()) + .unwrap_or_default(); + let device = burn::backend::cuda::CudaDevice::new(gpu_id); + + // load model + let path: std::path::PathBuf = args.next().unwrap().parse().unwrap(); + let mut model: Net = Net::init(&device); + model = model + .load_file( + path, + &burn::record::BinFileRecorder::::new(), + &device, + ) + .unwrap(); + // load map + let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); + let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) + .unwrap() + .into_complete_map() + .unwrap(); + let modes = map.modes.clone().denormalize(); + let mode = modes + .get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN) + .unwrap(); + let start_zone = map.models.get(mode.get_start().get() as usize).unwrap(); + let start_offset = glam::Vec3::from_array( + start_zone + .transform + .translation + .map(|f| f.into()) + .to_array(), + ); + + // setup graphics + let mut g = GraphicsState::new(&map); + // setup simulation + let mut session = Session { + geometry_shared: PhysicsData::new(&map), + simulation: PhysicsState::default(), + recording: Recording { + instructions: Vec::new(), + }, + }; + + let mut time = PhysicsTime::ZERO; + + // reset to start zone + session.run(time, PhysicsInputInstruction::Mode(ModeInstruction::Reset)); + // session.run( + // time, + // PhysicsInputInstruction::Misc(MiscInstruction::SetSensitivity(?)), + // ); + session.run( + time, + PhysicsInputInstruction::Mode(ModeInstruction::Restart( + strafesnet_common::gameplay_modes::ModeId::MAIN, + )), + ); + + // TEMP: turn mouse left + let mut mouse_pos = glam::ivec2(-5300, 0); + + const STEP: PhysicsTime = PhysicsTime::from_millis(10); + let mut input_floats = Vec::new(); // setup agent-simulation feedback loop - // go! + for _ in 0..20 * 100 { + // generate inputs + let frame_state = session.get_frame_state(); + g.generate_inputs( + frame_state.pos(time) - start_offset, + frame_state.angles(), + &mut input_floats, + ); + + // inference + let inputs = Tensor::from_data( + TensorData::new(input_floats.clone(), Shape::new([1, INPUT])), + &device, + ); + let outputs = model.forward(inputs).into_data().into_vec::().unwrap(); + + let &[ + move_forward, + move_left, + move_back, + move_right, + jump, + mouse_dx, + mouse_dy, + ] = outputs.as_slice() + else { + panic!() + }; + + macro_rules! set_control { + ($control:ident,$output:expr) => { + session.run( + time, + PhysicsInputInstruction::SetControl(SetControlInstruction::$control( + 0.5 < $output, + )), + ); + }; + } + set_control!(SetMoveForward, move_forward); + set_control!(SetMoveLeft, move_left); + set_control!(SetMoveBack, move_back); + set_control!(SetMoveRight, move_right); + set_control!(SetJump, jump); + + mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2(); + let next_time = time + STEP; + session.run( + time, + PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(MouseState { + pos: mouse_pos, + time: next_time, + })), + ); + + time = next_time; + + // clear + input_floats.clear(); + } + + let date_string = format!("{}.snfb", chrono::Utc::now()); + let file = std::fs::File::create(date_string).unwrap(); + strafesnet_snf::bot::write_bot( + std::io::BufWriter::new(file), + strafesnet_physics::VERSION.get(), + core::mem::take(&mut session.recording.instructions), + ) + .unwrap(); } fn main() { - training(); + // training(); + inference(); }