244 lines
5.7 KiB
Rust
244 lines
5.7 KiB
Rust
#[derive(clap::Subcommand)]
|
|
pub enum Commands {
|
|
Simulate(SimulateSubcommand),
|
|
}
|
|
impl Commands {
|
|
pub fn run(self) {
|
|
match self {
|
|
Commands::Simulate(subcommand) => subcommand.run(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(clap::Args)]
|
|
pub struct SimulateSubcommand {
|
|
#[arg(long)]
|
|
gpu_id: Option<usize>,
|
|
#[arg(long)]
|
|
model_file: std::path::PathBuf,
|
|
#[arg(long)]
|
|
output_file: Option<std::path::PathBuf>,
|
|
#[arg(long)]
|
|
map_file: std::path::PathBuf,
|
|
}
|
|
|
|
impl SimulateSubcommand {
|
|
fn run(self) {
|
|
let output_file = self.output_file.unwrap_or_else(|| {
|
|
let mut file_name = self
|
|
.model_file
|
|
.file_stem()
|
|
.unwrap()
|
|
.to_str()
|
|
.unwrap()
|
|
.to_owned();
|
|
file_name.push_str("_replay.snfb");
|
|
let mut path = self.model_file.clone();
|
|
path.set_file_name(file_name);
|
|
path
|
|
});
|
|
inference(
|
|
self.gpu_id.unwrap_or_default(),
|
|
self.model_file,
|
|
output_file,
|
|
self.map_file,
|
|
);
|
|
}
|
|
}
|
|
|
|
use burn::prelude::*;
|
|
|
|
use crate::inputs::InputGenerator;
|
|
use crate::net::{INPUT, InferenceBackend, Net};
|
|
|
|
use strafesnet_common::instruction::TimedInstruction;
|
|
use strafesnet_common::mouse::MouseState;
|
|
use strafesnet_common::physics::{
|
|
Instruction as PhysicsInputInstruction, MiscInstruction, ModeInstruction, MouseInstruction,
|
|
SetControlInstruction, Time as PhysicsTime,
|
|
};
|
|
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
|
|
|
|
pub struct Recording {
|
|
instructions: Vec<TimedInstruction<PhysicsInputInstruction, PhysicsTime>>,
|
|
}
|
|
struct FrameState {
|
|
trajectory: strafesnet_physics::physics::Trajectory,
|
|
camera: strafesnet_physics::physics::PhysicsCamera,
|
|
}
|
|
impl FrameState {
|
|
fn pos(&self, time: PhysicsTime) -> glam::Vec3 {
|
|
self.trajectory
|
|
.extrapolated_position(time)
|
|
.map(Into::<f32>::into)
|
|
.to_array()
|
|
.into()
|
|
}
|
|
fn angles(&self) -> glam::Vec2 {
|
|
self.camera.simulate_move_angles(glam::IVec2::ZERO)
|
|
}
|
|
}
|
|
struct Session {
|
|
geometry_shared: PhysicsData,
|
|
simulation: PhysicsState,
|
|
recording: Recording,
|
|
}
|
|
impl Session {
|
|
fn get_frame_state(&self) -> FrameState {
|
|
FrameState {
|
|
trajectory: self.simulation.camera_trajectory(&self.geometry_shared),
|
|
camera: self.simulation.camera(),
|
|
}
|
|
}
|
|
fn run(&mut self, time: PhysicsTime, instruction: PhysicsInputInstruction) {
|
|
let instruction = TimedInstruction { time, instruction };
|
|
self.recording.instructions.push(instruction.clone());
|
|
PhysicsContext::run_input_instruction(
|
|
&mut self.simulation,
|
|
&self.geometry_shared,
|
|
instruction,
|
|
);
|
|
}
|
|
}
|
|
|
|
fn inference(
|
|
gpu_id: usize,
|
|
model_file: std::path::PathBuf,
|
|
output_file: std::path::PathBuf,
|
|
map_file: std::path::PathBuf,
|
|
) {
|
|
// pick device
|
|
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
|
|
|
// load model
|
|
let mut model: Net<InferenceBackend> = Net::init(&device);
|
|
model = model
|
|
.load_file(
|
|
model_file,
|
|
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
|
&device,
|
|
)
|
|
.unwrap();
|
|
|
|
// load map
|
|
let map_file = std::fs::read(map_file).unwrap();
|
|
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
|
.unwrap()
|
|
.into_complete_map()
|
|
.unwrap();
|
|
let modes = map.modes.clone().denormalize();
|
|
let mode = modes
|
|
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
|
|
.unwrap();
|
|
let start_zone = map.models.get(mode.get_start().get() as usize).unwrap();
|
|
let start_offset = glam::Vec3::from_array(
|
|
start_zone
|
|
.transform
|
|
.translation
|
|
.map(|f| f.into())
|
|
.to_array(),
|
|
);
|
|
|
|
// setup graphics
|
|
let mut g = InputGenerator::new(&map);
|
|
|
|
// setup simulation
|
|
let mut session = Session {
|
|
geometry_shared: PhysicsData::new(&map),
|
|
simulation: PhysicsState::default(),
|
|
recording: Recording {
|
|
instructions: Vec::new(),
|
|
},
|
|
};
|
|
|
|
let mut time = PhysicsTime::ZERO;
|
|
|
|
// reset to start zone
|
|
session.run(time, PhysicsInputInstruction::Mode(ModeInstruction::Reset));
|
|
// session.run(
|
|
// time,
|
|
// PhysicsInputInstruction::Misc(MiscInstruction::SetSensitivity(?)),
|
|
// );
|
|
session.run(
|
|
time,
|
|
PhysicsInputInstruction::Mode(ModeInstruction::Restart(
|
|
strafesnet_common::gameplay_modes::ModeId::MAIN,
|
|
)),
|
|
);
|
|
|
|
// TEMP: turn mouse left
|
|
let mut mouse_pos = glam::ivec2(-5300, 0);
|
|
|
|
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
|
|
let mut input_floats = Vec::new();
|
|
// setup agent-simulation feedback loop
|
|
for _ in 0..20 * 100 {
|
|
// generate inputs
|
|
let frame_state = session.get_frame_state();
|
|
g.generate_inputs(
|
|
frame_state.pos(time) - start_offset,
|
|
frame_state.angles(),
|
|
&mut input_floats,
|
|
);
|
|
|
|
// inference
|
|
let inputs = Tensor::from_data(
|
|
TensorData::new(input_floats.clone(), Shape::new([1, INPUT])),
|
|
&device,
|
|
);
|
|
let outputs = model.forward(inputs).into_data().into_vec::<f32>().unwrap();
|
|
|
|
let &[
|
|
move_forward,
|
|
move_left,
|
|
move_back,
|
|
move_right,
|
|
jump,
|
|
mouse_dx,
|
|
mouse_dy,
|
|
] = outputs.as_slice()
|
|
else {
|
|
panic!()
|
|
};
|
|
|
|
macro_rules! set_control {
|
|
($control:ident,$output:expr) => {
|
|
session.run(
|
|
time,
|
|
PhysicsInputInstruction::SetControl(SetControlInstruction::$control(
|
|
0.5 < $output,
|
|
)),
|
|
);
|
|
};
|
|
}
|
|
set_control!(SetMoveForward, move_forward);
|
|
set_control!(SetMoveLeft, move_left);
|
|
set_control!(SetMoveBack, move_back);
|
|
set_control!(SetMoveRight, move_right);
|
|
set_control!(SetJump, jump);
|
|
|
|
mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
|
|
let next_time = time + STEP;
|
|
session.run(
|
|
time,
|
|
PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(MouseState {
|
|
pos: mouse_pos,
|
|
time: next_time,
|
|
})),
|
|
);
|
|
|
|
time = next_time;
|
|
|
|
// clear
|
|
input_floats.clear();
|
|
}
|
|
|
|
let file = std::fs::File::create(output_file).unwrap();
|
|
strafesnet_snf::bot::write_bot(
|
|
std::io::BufWriter::new(file),
|
|
strafesnet_physics::VERSION.get(),
|
|
core::mem::take(&mut session.recording.instructions),
|
|
)
|
|
.unwrap();
|
|
}
|