don't hardcode map and bot
This commit is contained in:
@@ -15,31 +15,34 @@ pub struct SimulateSubcommand {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
gpu_id: Option<usize>,
|
gpu_id: Option<usize>,
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_path: std::path::PathBuf,
|
model_file: std::path::PathBuf,
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
output_file: Option<std::path::PathBuf>,
|
output_file: Option<std::path::PathBuf>,
|
||||||
|
#[arg(long)]
|
||||||
|
map_file: std::path::PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SimulateSubcommand {
|
impl SimulateSubcommand {
|
||||||
fn run(self) {
|
fn run(self) {
|
||||||
let output_file = self.output_file.unwrap_or_else(|| {
|
let output_file = self.output_file.unwrap_or_else(|| {
|
||||||
let mut file_name = self
|
let mut file_name = self
|
||||||
.model_path
|
.model_file
|
||||||
.file_stem()
|
.file_stem()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_str()
|
.to_str()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_owned();
|
.to_owned();
|
||||||
file_name.push_str("_replay");
|
file_name.push_str("_replay");
|
||||||
let mut path = self.model_path.clone();
|
let mut path = self.model_file.clone();
|
||||||
path.set_file_name(file_name);
|
path.set_file_name(file_name);
|
||||||
path.set_extension("snfb");
|
path.set_extension("snfb");
|
||||||
path
|
path
|
||||||
});
|
});
|
||||||
inference(
|
inference(
|
||||||
self.gpu_id.unwrap_or_default(),
|
self.gpu_id.unwrap_or_default(),
|
||||||
self.model_path,
|
self.model_file,
|
||||||
output_file,
|
output_file,
|
||||||
|
self.map_file,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,7 +102,12 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::path::PathBuf) {
|
fn inference(
|
||||||
|
gpu_id: usize,
|
||||||
|
model_file: std::path::PathBuf,
|
||||||
|
output_file: std::path::PathBuf,
|
||||||
|
map_file: std::path::PathBuf,
|
||||||
|
) {
|
||||||
// pick device
|
// pick device
|
||||||
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
||||||
|
|
||||||
@@ -107,14 +115,14 @@ fn inference(gpu_id: usize, model_path: std::path::PathBuf, output_file: std::pa
|
|||||||
let mut model: Net<InferenceBackend> = Net::init(&device);
|
let mut model: Net<InferenceBackend> = Net::init(&device);
|
||||||
model = model
|
model = model
|
||||||
.load_file(
|
.load_file(
|
||||||
model_path,
|
model_file,
|
||||||
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
||||||
&device,
|
&device,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// load map
|
// load map
|
||||||
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
let map_file = std::fs::read(map_file).unwrap();
|
||||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_complete_map()
|
.into_complete_map()
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ pub struct TrainSubcommand {
|
|||||||
epochs: Option<usize>,
|
epochs: Option<usize>,
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
learning_rate: Option<f64>,
|
learning_rate: Option<f64>,
|
||||||
|
#[arg(long)]
|
||||||
|
map_file: std::path::PathBuf,
|
||||||
|
#[arg(long)]
|
||||||
|
bot_file: std::path::PathBuf,
|
||||||
}
|
}
|
||||||
impl TrainSubcommand {
|
impl TrainSubcommand {
|
||||||
fn run(self) {
|
fn run(self) {
|
||||||
@@ -25,6 +29,8 @@ impl TrainSubcommand {
|
|||||||
self.gpu_id.unwrap_or_default(),
|
self.gpu_id.unwrap_or_default(),
|
||||||
self.epochs.unwrap_or(100_000),
|
self.epochs.unwrap_or(100_000),
|
||||||
self.learning_rate.unwrap_or(0.001),
|
self.learning_rate.unwrap_or(0.001),
|
||||||
|
self.map_file,
|
||||||
|
self.bot_file,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -38,28 +44,28 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
|
|||||||
|
|
||||||
use strafesnet_roblox_bot_file::v0;
|
use strafesnet_roblox_bot_file::v0;
|
||||||
|
|
||||||
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) {
|
fn training(
|
||||||
// load map
|
gpu_id: usize,
|
||||||
// load replay
|
epochs: usize,
|
||||||
// setup player
|
learning_rate: f64,
|
||||||
|
map_file: std::path::PathBuf,
|
||||||
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
bot_file: std::path::PathBuf,
|
||||||
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
) {
|
||||||
|
|
||||||
// read files
|
// read files
|
||||||
|
let map_file = std::fs::read(map_file).unwrap();
|
||||||
|
let bot_file = std::fs::read(bot_file).unwrap();
|
||||||
|
// load map
|
||||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_complete_map()
|
.into_complete_map()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
// load replay
|
||||||
let timelines =
|
let timelines =
|
||||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
||||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
||||||
let world_offset = bot.world_offset();
|
let world_offset = bot.world_offset();
|
||||||
let timelines = bot.timelines();
|
let timelines = bot.timelines();
|
||||||
|
|
||||||
// setup simulation
|
|
||||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
|
||||||
|
|
||||||
// set up graphics
|
// set up graphics
|
||||||
let mut g = InputGenerator::new(&map);
|
let mut g = InputGenerator::new(&map);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user