training options

This commit is contained in:
2026-03-27 15:56:13 -07:00
parent 18cad85b62
commit fb8c6e2492

View File

@@ -14,10 +14,18 @@ impl Commands {
pub struct TrainSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
epochs: Option<usize>,
#[arg(long)]
learning_rate: Option<f64>,
}
impl TrainSubcommand {
fn run(self) {
training(self.gpu_id.unwrap_or_default());
training(
self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001),
);
}
}
@@ -30,7 +38,7 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
use strafesnet_roblox_bot_file::v0;
fn training(gpu_id: usize) {
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) {
// load map
// load replay
// setup player
@@ -165,13 +173,10 @@ fn training(gpu_id: usize) {
&device,
);
const LEARNING_RATE: f64 = 0.001;
const EPOCHS: usize = 100000;
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..EPOCHS {
for epoch in 0..epochs {
let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
@@ -192,9 +197,9 @@ fn training(gpu_id: usize) {
best_model = model.clone();
}
model = optim.step(LEARNING_RATE, model, grads);
model = optim.step(learning_rate, model, grads);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
if epoch % (epochs >> 4) == 0 || epoch == epochs - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}