don't hardcode map and bot

This commit is contained in:
2026-03-27 16:46:02 -07:00
parent e38c0a92b4
commit 48f9657d0f
2 changed files with 32 additions and 18 deletions

View File

@@ -15,31 +15,34 @@ pub struct SimulateSubcommand {
#[arg(long)] #[arg(long)]
gpu_id: Option<usize>, gpu_id: Option<usize>,
#[arg(long)] #[arg(long)]
model_path: std::path::PathBuf, model_file: std::path::PathBuf,
#[arg(long)] #[arg(long)]
output_file: Option<std::path::PathBuf>, output_file: Option<std::path::PathBuf>,
#[arg(long)]
map_file: std::path::PathBuf,
} }
impl SimulateSubcommand { impl SimulateSubcommand {
fn run(self) { fn run(self) {
let output_file = self.output_file.unwrap_or_else(|| { let output_file = self.output_file.unwrap_or_else(|| {
let mut file_name = self let mut file_name = self
.model_path .model_file
.file_stem() .file_stem()
.unwrap() .unwrap()
.to_str() .to_str()
.unwrap() .unwrap()
.to_owned(); .to_owned();
file_name.push_str("_replay"); file_name.push_str("_replay");
let mut path = self.model_path.clone(); let mut path = self.model_file.clone();
path.set_file_name(file_name); path.set_file_name(file_name);
path.set_extension("snfb"); path.set_extension("snfb");
path path
}); });
inference( inference(
self.gpu_id.unwrap_or_default(), self.gpu_id.unwrap_or_default(),
self.model_path, self.model_file,
output_file, output_file,
self.map_file,
); );
} }
} }
@@ -99,7 +102,12 @@ impl Session {
} }
} }
fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) { fn inference(
gpu_id: usize,
model_file: std::path::PathBuf,
output_file: std::path::PathBuf,
map_file: std::path::PathBuf,
) {
// pick device // pick device
let device = burn::backend::cuda::CudaDevice::new(gpu_id); let device = burn::backend::cuda::CudaDevice::new(gpu_id);
@@ -107,14 +115,14 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::pa
let mut model: Net<InferenceBackend> = Net::init(&device); let mut model: Net<InferenceBackend> = Net::init(&device);
model = model model = model
.load_file( .load_file(
model_path, model_file,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(), &burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device, &device,
) )
.unwrap(); .unwrap();
// load map // load map
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); let map_file = std::fs::read(map_file).unwrap();
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap() .unwrap()
.into_complete_map() .into_complete_map()

View File

@@ -18,6 +18,10 @@ pub struct TrainSubcommand {
epochs: Option<usize>, epochs: Option<usize>,
#[arg(long)] #[arg(long)]
learning_rate: Option<f64>, learning_rate: Option<f64>,
#[arg(long)]
map_file: std::path::PathBuf,
#[arg(long)]
bot_file: std::path::PathBuf,
} }
impl TrainSubcommand { impl TrainSubcommand {
fn run(self) { fn run(self) {
@@ -25,6 +29,8 @@ impl TrainSubcommand {
self.gpu_id.unwrap_or_default(), self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000), self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001), self.learning_rate.unwrap_or(0.001),
self.map_file,
self.bot_file,
); );
} }
} }
@@ -38,28 +44,28 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
use strafesnet_roblox_bot_file::v0; use strafesnet_roblox_bot_file::v0;
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) { fn training(
// load map gpu_id: usize,
// load replay epochs: usize,
// setup player learning_rate: f64,
map_file: std::path::PathBuf,
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm"); bot_file: std::path::PathBuf,
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot"); ) {
// read files // read files
let map_file = std::fs::read(map_file).unwrap();
let bot_file = std::fs::read(bot_file).unwrap();
// load map
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file)) let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap() .unwrap()
.into_complete_map() .into_complete_map()
.unwrap(); .unwrap();
// load replay
let timelines = let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap(); strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap(); let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset(); let world_offset = bot.world_offset();
let timelines = bot.timelines(); let timelines = bot.timelines();
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
// set up graphics // set up graphics
let mut g = InputGenerator::new(&map); let mut g = InputGenerator::new(&map);