training options
This commit is contained in:
@@ -14,10 +14,18 @@ impl Commands {
|
||||
pub struct TrainSubcommand {
|
||||
#[arg(long)]
|
||||
gpu_id: Option<usize>,
|
||||
#[arg(long)]
|
||||
epochs: Option<usize>,
|
||||
#[arg(long)]
|
||||
learning_rate: Option<f64>,
|
||||
}
|
||||
impl TrainSubcommand {
|
||||
fn run(self) {
|
||||
training(self.gpu_id.unwrap_or_default());
|
||||
training(
|
||||
self.gpu_id.unwrap_or_default(),
|
||||
self.epochs.unwrap_or(100_000),
|
||||
self.learning_rate.unwrap_or(0.001),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +38,7 @@ use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
|
||||
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
fn training(gpu_id: usize) {
|
||||
fn training(gpu_id: usize, epochs: usize, learning_rate: f64) {
|
||||
// load map
|
||||
// load replay
|
||||
// setup player
|
||||
@@ -165,13 +173,10 @@ fn training(gpu_id: usize) {
|
||||
&device,
|
||||
);
|
||||
|
||||
const LEARNING_RATE: f64 = 0.001;
|
||||
const EPOCHS: usize = 100000;
|
||||
|
||||
let mut best_model = model.clone();
|
||||
let mut best_loss = f32::INFINITY;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
for epoch in 0..epochs {
|
||||
let predictions = model.forward(inputs.clone());
|
||||
|
||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||
@@ -192,9 +197,9 @@ fn training(gpu_id: usize) {
|
||||
best_model = model.clone();
|
||||
}
|
||||
|
||||
model = optim.step(LEARNING_RATE, model, grads);
|
||||
model = optim.step(learning_rate, model, grads);
|
||||
|
||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||
if epoch % (epochs >> 4) == 0 || epoch == epochs - 1 {
|
||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user