diff --git a/src/inference.rs b/src/inference.rs index ea30b69..cbf776f 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -15,31 +15,34 @@ pub struct SimulateSubcommand { #[arg(long)] gpu_id: Option, #[arg(long)] - model_path: std::path::PathBuf, + model_file: std::path::PathBuf, #[arg(long)] output_file: Option, + #[arg(long)] + map_file: std::path::PathBuf, } impl SimulateSubcommand { fn run(self) { let output_file = self.output_file.unwrap_or_else(|| { let mut file_name = self - .model_path + .model_file .file_stem() .unwrap() .to_str() .unwrap() .to_owned(); file_name.push_str("_replay"); - let mut path = self.model_path.clone(); + let mut path = self.model_file.clone(); path.set_file_name(file_name); path.set_extension("snfb"); path }); inference( self.gpu_id.unwrap_or_default(), - self.model_path, + self.model_file, output_file, + self.map_file, ); } } @@ -99,7 +102,12 @@ impl Session { } } -fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) { +fn inference( + gpu_id: usize, + model_file: std::path::PathBuf, + output_file: std::path::PathBuf, + map_file: std::path::PathBuf, +) { // pick device let device = burn::backend::cuda::CudaDevice::new(gpu_id); @@ -107,14 +115,14 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::pa let mut model: Net = Net::init(&device); model = model .load_file( - model_path, + model_file, &burn::record::BinFileRecorder::::new(), &device, ) .unwrap(); // load map - let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); + let map_file = std::fs::read(map_file).unwrap(); let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) .unwrap() .into_complete_map() diff --git a/src/training.rs b/src/training.rs index 7e1e49a..7992bfc 100644 --- a/src/training.rs +++ b/src/training.rs @@ -18,6 +18,10 @@ pub struct TrainSubcommand { epochs: Option, #[arg(long)] learning_rate: Option, + #[arg(long)] + map_file: std::path::PathBuf, + #[arg(long)] + bot_file: std::path::PathBuf, } impl TrainSubcommand { fn run(self) { @@ -25,6 +29,8 @@ impl TrainSubcommand { self.gpu_id.unwrap_or_default(), self.epochs.unwrap_or(100_000), self.learning_rate.unwrap_or(0.001), + self.map_file, + self.bot_file, ); } } @@ -38,28 +44,28 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend}; use strafesnet_roblox_bot_file::v0; -fn training(gpu_id: usize, epochs: usize, learning_rate: f64) { - // load map - // load replay - // setup player - - let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); - let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot"); - +fn training( + gpu_id: usize, + epochs: usize, + learning_rate: f64, + map_file: std::path::PathBuf, + bot_file: std::path::PathBuf, +) { // read files + let map_file = std::fs::read(map_file).unwrap(); + let bot_file = std::fs::read(bot_file).unwrap(); + // load map let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) .unwrap() .into_complete_map() .unwrap(); + // load replay let timelines = strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap(); let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap(); let world_offset = bot.world_offset(); let timelines = bot.timelines(); - // setup simulation - // run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map - // set up graphics let mut g = InputGenerator::new(&map);