diff --git a/src/inference.rs b/src/inference.rs index ffd3e73..ea30b69 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -16,11 +16,31 @@ pub struct SimulateSubcommand { gpu_id: Option, #[arg(long)] model_path: std::path::PathBuf, + #[arg(long)] + output_file: Option, } impl SimulateSubcommand { fn run(self) { - inference(self.gpu_id.unwrap_or_default(), self.model_path); + let output_file = self.output_file.unwrap_or_else(|| { + let mut file_name = self + .model_path + .file_stem() + .unwrap() + .to_str() + .unwrap() + .to_owned(); + file_name.push_str("_replay"); + let mut path = self.model_path.clone(); + path.set_file_name(file_name); + path.set_extension("snfb"); + path + }); + inference( + self.gpu_id.unwrap_or_default(), + self.model_path, + output_file, + ); } } @@ -79,7 +99,7 @@ impl Session { } } -fn inference(gpu_id: usize, model_path: std::path::PathBuf) { +fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) { // pick device let device = burn::backend::cuda::CudaDevice::new(gpu_id); @@ -206,8 +226,7 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf) { input_floats.clear(); } - let date_string = format!("{}.snfb", chrono::Utc::now()); - let file = std::fs::File::create(date_string).unwrap(); + let file = std::fs::File::create(output_file).unwrap(); strafesnet_snf::bot::write_bot( std::io::BufWriter::new(file), strafesnet_physics::VERSION.get(),