forked from StrafesNET/strafe-ai
add cli args
This commit is contained in:
@@ -11,11 +11,16 @@ impl Commands {
|
||||
}
|
||||
|
||||
#[derive(clap::Args)]
|
||||
pub struct SimulateSubcommand {}
|
||||
pub struct SimulateSubcommand {
|
||||
#[arg(long)]
|
||||
gpu_id: Option<usize>,
|
||||
#[arg(long)]
|
||||
model_path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl SimulateSubcommand {
|
||||
fn run(self) {
|
||||
inference();
|
||||
inference(self.gpu_id.unwrap_or_default(), self.model_path);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,22 +79,15 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn inference() {
|
||||
let mut args = std::env::args().skip(1);
|
||||
|
||||
fn inference(gpu_id: usize, model_path: std::path::PathBuf) {
|
||||
// pick device
|
||||
let gpu_id: usize = args
|
||||
.next()
|
||||
.map(|id| id.parse().unwrap())
|
||||
.unwrap_or_default();
|
||||
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
||||
|
||||
// load model
|
||||
let path: std::path::PathBuf = args.next().unwrap().parse().unwrap();
|
||||
let mut model: Net<InferenceBackend> = Net::init(&device);
|
||||
model = model
|
||||
.load_file(
|
||||
path,
|
||||
model_path,
|
||||
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
||||
&device,
|
||||
)
|
||||
|
||||
@@ -11,10 +11,13 @@ impl Commands {
|
||||
}
|
||||
|
||||
#[derive(clap::Args)]
|
||||
pub struct TrainSubcommand {}
|
||||
pub struct TrainSubcommand {
|
||||
#[arg(long)]
|
||||
gpu_id: Option<usize>,
|
||||
}
|
||||
impl TrainSubcommand {
|
||||
fn run(self) {
|
||||
training();
|
||||
training(self.gpu_id.unwrap_or_default());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,12 +30,7 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
|
||||
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
fn training() {
|
||||
let gpu_id: usize = std::env::args()
|
||||
.skip(1)
|
||||
.next()
|
||||
.map(|id| id.parse().unwrap())
|
||||
.unwrap_or_default();
|
||||
fn training(gpu_id: usize) {
|
||||
// load map
|
||||
// load replay
|
||||
// setup player
|
||||
|
||||
Reference in New Issue
Block a user