diff --git a/src/main.rs b/src/main.rs index 6d4bb1d..858e3a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -328,6 +328,9 @@ fn training() { const LEARNING_RATE: f64 = 0.001; const EPOCHS: usize = 100000; + let mut best_model = model.clone(); + let mut best_loss = f32::INFINITY; + for epoch in 0..EPOCHS { let predictions = model.forward(inputs.clone()); @@ -343,6 +346,12 @@ fn training() { let grads = loss.backward(); let grads = GradientsParams::from_grads(grads, &model); + // get the best model + if loss_scalar < best_loss { + best_loss = loss_scalar; + best_model = model.clone(); + } + model = optim.step(LEARNING_RATE, model, grads); if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 { @@ -350,6 +359,14 @@ fn training() { println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar); } } + + let date_string = format!("{:?}_{}.model", std::time::Instant::now(), best_loss); + best_model + .save_file( + date_string, + &burn::record::BinFileRecorder::::new(), + ) + .unwrap(); } fn inference() {