add cli args

This commit is contained in:
2026-03-27 15:51:36 -07:00
parent b195a7eb95
commit 18cad85b62
2 changed files with 15 additions and 19 deletions

View File

@@ -11,11 +11,16 @@ impl Commands {
}
#[derive(clap::Args)]
pub struct SimulateSubcommand {}
pub struct SimulateSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
model_path: std::path::PathBuf,
}
impl SimulateSubcommand {
fn run(self) {
inference();
inference(self.gpu_id.unwrap_or_default(), self.model_path);
}
}
@@ -74,22 +79,15 @@ impl Session {
}
}
fn inference() {
let mut args = std::env::args().skip(1);
fn inference(gpu_id: usize, model_path: std::path::PathBuf) {
// pick device
let gpu_id: usize = args
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
// load model
let path: std::path::PathBuf = args.next().unwrap().parse().unwrap();
let mut model: Net<InferenceBackend> = Net::init(&device);
model = model
.load_file(
path,
model_path,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device,
)

View File

@@ -11,10 +11,13 @@ impl Commands {
}
#[derive(clap::Args)]
pub struct TrainSubcommand {}
pub struct TrainSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
}
impl TrainSubcommand {
fn run(self) {
training();
training(self.gpu_id.unwrap_or_default());
}
}
@@ -27,12 +30,7 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
use strafesnet_roblox_bot_file::v0;
fn training() {
let gpu_id: usize = std::env::args()
.skip(1)
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
fn training(gpu_id: usize) {
// load map
// load replay
// setup player