forked from StrafesNET/strafe-ai
tweaks
This commit is contained in:
10
src/main.rs
10
src/main.rs
@@ -1,11 +1,11 @@
|
|||||||
use burn::backend::{Autodiff, Cuda};
|
use burn::backend::Autodiff;
|
||||||
use burn::module::AutodiffModule;
|
use burn::module::AutodiffModule;
|
||||||
use burn::nn::loss::{MseLoss, Reduction};
|
use burn::nn::loss::{MseLoss, Reduction};
|
||||||
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
|
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
|
||||||
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
|
|
||||||
type InferenceBackend = Cuda<f32>;
|
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||||
|
|
||||||
const INPUT: usize = 2;
|
const INPUT: usize = 2;
|
||||||
@@ -58,7 +58,7 @@ fn training() {
|
|||||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
||||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
||||||
|
|
||||||
const LR: f64 = 0.5;
|
const LEARNING_RATE: f64 = 0.5;
|
||||||
const EPOCHS: usize = 100;
|
const EPOCHS: usize = 100;
|
||||||
|
|
||||||
for epoch in 0..EPOCHS {
|
for epoch in 0..EPOCHS {
|
||||||
@@ -66,7 +66,7 @@ fn training() {
|
|||||||
|
|
||||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||||
|
|
||||||
if epoch % 2000 == 0 || epoch == EPOCHS - 1 {
|
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||||
println!(
|
println!(
|
||||||
" epoch {:>5} | loss = {:.8}",
|
" epoch {:>5} | loss = {:.8}",
|
||||||
@@ -78,7 +78,7 @@ fn training() {
|
|||||||
let grads = loss.backward();
|
let grads = loss.backward();
|
||||||
let grads = GradientsParams::from_grads(grads, &model);
|
let grads = GradientsParams::from_grads(grads, &model);
|
||||||
|
|
||||||
model = optim.step(LR, model, grads);
|
model = optim.step(LEARNING_RATE, model, grads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user