forked from StrafesNET/strafe-ai
tweaks
This commit is contained in:
10
src/main.rs
10
src/main.rs
@@ -1,11 +1,11 @@
|
||||
use burn::backend::{Autodiff, Cuda};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::nn::loss::{MseLoss, Reduction};
|
||||
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
|
||||
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
||||
use burn::prelude::*;
|
||||
|
||||
type InferenceBackend = Cuda<f32>;
|
||||
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||
|
||||
const INPUT: usize = 2;
|
||||
@@ -58,7 +58,7 @@ fn training() {
|
||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
||||
|
||||
const LR: f64 = 0.5;
|
||||
const LEARNING_RATE: f64 = 0.5;
|
||||
const EPOCHS: usize = 100;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
@@ -66,7 +66,7 @@ fn training() {
|
||||
|
||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||
|
||||
if epoch % 2000 == 0 || epoch == EPOCHS - 1 {
|
||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||
println!(
|
||||
" epoch {:>5} | loss = {:.8}",
|
||||
@@ -78,7 +78,7 @@ fn training() {
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
|
||||
model = optim.step(LR, model, grads);
|
||||
model = optim.step(LEARNING_RATE, model, grads);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user