From e890623f2e4e8676c0e1dd5458c258f063c61193 Mon Sep 17 00:00:00 2001
From: Rhys Lloyd
Date: Fri, 27 Mar 2026 14:57:42 -0700
Subject: [PATCH] add dropout to input

---
 src/main.rs | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/main.rs b/src/main.rs
index fcabaed..b1ca2ae 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,6 @@
 use burn::backend::Autodiff;
 use burn::nn::loss::{MseLoss, Reduction};
-use burn::nn::{Linear, LinearConfig, Relu};
+use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
 use burn::optim::{AdamConfig, GradientsParams, Optimizer};
 use burn::prelude::*;
 
@@ -30,6 +30,7 @@ const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256
 #[derive(Module, Debug)]
 struct Net<B: Backend> {
     input: Linear<B>,
+    dropout: Dropout,
     hidden: [Linear<B>; HIDDEN.len() - 1],
     output: Linear<B>,
     activation: Relu,
@@ -46,8 +47,10 @@ impl<B: Backend> Net<B> {
             layer
         });
         let output = LinearConfig::new(last_size, OUTPUT).init(device);
+        let dropout = DropoutConfig::new(0.1).init();
         Self {
             input,
+            dropout,
             hidden,
             output,
             activation: Relu::new(),
@@ -55,6 +58,7 @@ impl<B: Backend> Net<B> {
     fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
         let x = self.input.forward(input);
+        let x = self.dropout.forward(x);
         let mut x = self.activation.forward(x);
         for layer in &self.hidden {
             x = layer.forward(x);