save best model

This commit is contained in:
2026-03-27 11:25:34 -07:00
parent 59bb8eee12
commit e19c46d851

View File

@@ -328,6 +328,9 @@ fn training() {
const LEARNING_RATE: f64 = 0.001;
const EPOCHS: usize = 100000;
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..EPOCHS {
let predictions = model.forward(inputs.clone());
@@ -343,6 +346,12 @@ fn training() {
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
// get the best model
if loss_scalar < best_loss {
best_loss = loss_scalar;
best_model = model.clone();
}
model = optim.step(LEARNING_RATE, model, grads);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
@@ -350,6 +359,14 @@ fn training() {
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}
}
let date_string = format!("{:?}_{}.model", std::time::Instant::now(), best_loss);
best_model
.save_file(
date_string,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
)
.unwrap();
}
fn inference() {