2026-03-25 15:23:36 -07:00
parent 4d2daaf2e0
commit c8b2480e6e


@@ -1,11 +1,11 @@
-use burn::backend::{Autodiff, Cuda};
+use burn::backend::Autodiff;
 use burn::module::AutodiffModule;
 use burn::nn::loss::{MseLoss, Reduction};
 use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
 use burn::optim::{GradientsParams, Optimizer, SgdConfig};
 use burn::prelude::*;
-type InferenceBackend = Cuda<f32>;
+type InferenceBackend = burn::backend::Cuda<f32>;
 type TrainingBackend = Autodiff<InferenceBackend>;
 const INPUT: usize = 2;
@@ -58,7 +58,7 @@ fn training() {
     let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
     let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
-    const LR: f64 = 0.5;
+    const LEARNING_RATE: f64 = 0.5;
     const EPOCHS: usize = 100;
     for epoch in 0..EPOCHS {
@@ -66,7 +66,7 @@ fn training() {
         let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
-        if epoch % 2000 == 0 || epoch == EPOCHS - 1 {
+        if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
             // .clone().into_scalar() extracts the f32 value from a 1-element tensor.
             println!(
                 " epoch {:>5} | loss = {:.8}",
@@ -78,7 +78,7 @@ fn training() {
         let grads = loss.backward();
         let grads = GradientsParams::from_grads(grads, &model);
-        model = optim.step(LR, model, grads);
+        model = optim.step(LEARNING_RATE, model, grads);
     }
 }
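
For reference, below is a minimal standalone sketch of how the training program reads after this change. Only the lines that also appear in the hunks above come from the actual file; the Model struct, the HIDDEN width, OUTPUT = 1, and the main wrapper are assumptions added so the sketch compiles on its own. With EPOCHS = 100, the new logging condition epoch % (EPOCHS >> 4) reduces to epoch % 6, so the loop prints roughly every sixth epoch plus the final one, whereas the old epoch % 2000 test could only ever fire at epoch 0.

// Sketch only: the full file is not shown in this diff. The model definition,
// OUTPUT, HIDDEN, and main() are guesses for illustration; everything else
// mirrors the lines visible in the hunks.
use burn::backend::Autodiff;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
use burn::prelude::*;

type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;

const INPUT: usize = 2;
const OUTPUT: usize = 1; // assumption: not visible in the hunks
const HIDDEN: usize = 4; // assumption: hidden width is not visible either

// Assumed two-layer network matching the imported Linear/Relu/Sigmoid modules.
#[derive(Module, Debug)]
struct Model<B: Backend> {
    hidden: Linear<B>,
    output: Linear<B>,
    relu: Relu,
    sigmoid: Sigmoid,
}

impl<B: Backend> Model<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            hidden: LinearConfig::new(INPUT, HIDDEN).init(device),
            output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
            relu: Relu::new(),
            sigmoid: Sigmoid::new(),
        }
    }

    fn forward(&self, x: Tensor<B, 1>) -> Tensor<B, 1> {
        let x = self.relu.forward(self.hidden.forward(x));
        self.sigmoid.forward(self.output.forward(x))
    }
}

fn training() {
    let device = Default::default();
    let mut model: Model<TrainingBackend> = Model::new(&device);
    let mut optim = SgdConfig::new().init();

    // Placeholder data, as in the diff: all-zero input and target vectors.
    let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
    let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);

    const LEARNING_RATE: f64 = 0.5;
    const EPOCHS: usize = 100;
    for epoch in 0..EPOCHS {
        let predictions = model.forward(inputs.clone());
        let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);

        // EPOCHS >> 4 == 100 / 16 == 6, so this logs every 6th epoch and the last one.
        if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
            // .clone().into_scalar() extracts the f32 value from a 1-element tensor.
            println!(
                " epoch {:>5} | loss = {:.8}",
                epoch,
                loss.clone().into_scalar()
            );
        }

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &model);
        model = optim.step(LEARNING_RATE, model, grads);
    }
}

fn main() {
    training();
}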