don't hardcode map and bot

This commit is contained in:
2026-03-27 16:46:02 -07:00
parent e38c0a92b4
commit 48f9657d0f
2 changed files with 32 additions and 18 deletions

View File

@@ -15,31 +15,34 @@ pub struct SimulateSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
model_path: std::path::PathBuf,
model_file: std::path::PathBuf,
#[arg(long)]
output_file: Option<std::path::PathBuf>,
#[arg(long)]
map_file: std::path::PathBuf,
}
impl SimulateSubcommand {
fn run(self) {
let output_file = self.output_file.unwrap_or_else(|| {
let mut file_name = self
.model_path
.model_file
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_owned();
file_name.push_str("_replay");
let mut path = self.model_path.clone();
let mut path = self.model_file.clone();
path.set_file_name(file_name);
path.set_extension("snfb");
path
});
inference(
self.gpu_id.unwrap_or_default(),
self.model_path,
self.model_file,
output_file,
self.map_file,
);
}
}
@@ -99,7 +102,12 @@ impl Session {
}
}
fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) {
fn inference(
gpu_id: usize,
model_file: std::path::PathBuf,
output_file: std::path::PathBuf,
map_file: std::path::PathBuf,
) {
// pick device
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
@@ -107,14 +115,14 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::pa
let mut model: Net<InferenceBackend> = Net::init(&device);
model = model
.load_file(
model_path,
model_file,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device,
)
.unwrap();
// load map
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
let map_file = std::fs::read(map_file).unwrap();
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()

View File

@@ -18,6 +18,10 @@ pub struct TrainSubcommand {
epochs: Option<usize>,
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long)]
map_file: std::path::PathBuf,
#[arg(long)]
bot_file: std::path::PathBuf,
}
impl TrainSubcommand {
fn run(self) {
@@ -25,6 +29,8 @@ impl TrainSubcommand {
self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001),
self.map_file,
self.bot_file,
);
}
}
@@ -38,28 +44,28 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
use strafesnet_roblox_bot_file::v0;
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) {
// load map
// load replay
// setup player
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
fn training(
gpu_id: usize,
epochs: usize,
learning_rate: f64,
map_file: std::path::PathBuf,
bot_file: std::path::PathBuf,
) {
// read files
let map_file = std::fs::read(map_file).unwrap();
let bot_file = std::fs::read(bot_file).unwrap();
// load map
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
// load replay
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
// set up graphics
let mut g = InputGenerator::new(&map);