don't hardcode map and bot
This commit is contained in:
@@ -15,31 +15,34 @@ pub struct SimulateSubcommand {
|
||||
#[arg(long)]
|
||||
gpu_id: Option<usize>,
|
||||
#[arg(long)]
|
||||
model_path: std::path::PathBuf,
|
||||
model_file: std::path::PathBuf,
|
||||
#[arg(long)]
|
||||
output_file: Option<std::path::PathBuf>,
|
||||
#[arg(long)]
|
||||
map_file: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl SimulateSubcommand {
|
||||
fn run(self) {
|
||||
let output_file = self.output_file.unwrap_or_else(|| {
|
||||
let mut file_name = self
|
||||
.model_path
|
||||
.model_file
|
||||
.file_stem()
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_owned();
|
||||
file_name.push_str("_replay");
|
||||
let mut path = self.model_path.clone();
|
||||
let mut path = self.model_file.clone();
|
||||
path.set_file_name(file_name);
|
||||
path.set_extension("snfb");
|
||||
path
|
||||
});
|
||||
inference(
|
||||
self.gpu_id.unwrap_or_default(),
|
||||
self.model_path,
|
||||
self.model_file,
|
||||
output_file,
|
||||
self.map_file,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -99,7 +102,12 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) {
|
||||
fn inference(
|
||||
gpu_id: usize,
|
||||
model_file: std::path::PathBuf,
|
||||
output_file: std::path::PathBuf,
|
||||
map_file: std::path::PathBuf,
|
||||
) {
|
||||
// pick device
|
||||
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
||||
|
||||
@@ -107,14 +115,14 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::pa
|
||||
let mut model: Net<InferenceBackend> = Net::init(&device);
|
||||
model = model
|
||||
.load_file(
|
||||
model_path,
|
||||
model_file,
|
||||
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// load map
|
||||
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
||||
let map_file = std::fs::read(map_file).unwrap();
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.unwrap()
|
||||
.into_complete_map()
|
||||
|
||||
@@ -18,6 +18,10 @@ pub struct TrainSubcommand {
|
||||
epochs: Option<usize>,
|
||||
#[arg(long)]
|
||||
learning_rate: Option<f64>,
|
||||
#[arg(long)]
|
||||
map_file: std::path::PathBuf,
|
||||
#[arg(long)]
|
||||
bot_file: std::path::PathBuf,
|
||||
}
|
||||
impl TrainSubcommand {
|
||||
fn run(self) {
|
||||
@@ -25,6 +29,8 @@ impl TrainSubcommand {
|
||||
self.gpu_id.unwrap_or_default(),
|
||||
self.epochs.unwrap_or(100_000),
|
||||
self.learning_rate.unwrap_or(0.001),
|
||||
self.map_file,
|
||||
self.bot_file,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -38,28 +44,28 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
|
||||
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) {
|
||||
// load map
|
||||
// load replay
|
||||
// setup player
|
||||
|
||||
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
||||
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
||||
|
||||
fn training(
|
||||
gpu_id: usize,
|
||||
epochs: usize,
|
||||
learning_rate: f64,
|
||||
map_file: std::path::PathBuf,
|
||||
bot_file: std::path::PathBuf,
|
||||
) {
|
||||
// read files
|
||||
let map_file = std::fs::read(map_file).unwrap();
|
||||
let bot_file = std::fs::read(bot_file).unwrap();
|
||||
// load map
|
||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||
.unwrap()
|
||||
.into_complete_map()
|
||||
.unwrap();
|
||||
// load replay
|
||||
let timelines =
|
||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
||||
let world_offset = bot.world_offset();
|
||||
let timelines = bot.timelines();
|
||||
|
||||
// setup simulation
|
||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||
|
||||
// set up graphics
|
||||
let mut g = InputGenerator::new(&map);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user