forked from StrafesNET/strafe-ai
save best model
This commit is contained in:
17
src/main.rs
17
src/main.rs
@@ -328,6 +328,9 @@ fn training() {
|
||||
const LEARNING_RATE: f64 = 0.001;
|
||||
const EPOCHS: usize = 100000;
|
||||
|
||||
let mut best_model = model.clone();
|
||||
let mut best_loss = f32::INFINITY;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
let predictions = model.forward(inputs.clone());
|
||||
|
||||
@@ -343,6 +346,12 @@ fn training() {
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
|
||||
// get the best model
|
||||
if loss_scalar < best_loss {
|
||||
best_loss = loss_scalar;
|
||||
best_model = model.clone();
|
||||
}
|
||||
|
||||
model = optim.step(LEARNING_RATE, model, grads);
|
||||
|
||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||
@@ -350,6 +359,14 @@ fn training() {
|
||||
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
|
||||
}
|
||||
}
|
||||
|
||||
let date_string = format!("{:?}_{}.model", std::time::Instant::now(), best_loss);
|
||||
best_model
|
||||
.save_file(
|
||||
date_string,
|
||||
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn inference() {
|
||||
|
||||
Reference in New Issue
Block a user