training things

This commit is contained in:
2026-03-25 14:47:47 -07:00
parent 70d5dc5f31
commit 4d2daaf2e0

View File

@@ -12,6 +12,7 @@ const INPUT: usize = 2;
const HIDDEN: usize = 64;
const OUTPUT: usize = 8;
#[derive(Module, Debug)]
struct Net<B: Backend> {
input: Linear<B>,
hidden: [Linear<B>; 64],
@@ -47,6 +48,38 @@ fn training() {
// setup player
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map

// Device for tensor allocation — the backend's default (presumably CPU; confirm
// against the TrainingBackend type alias defined elsewhere in this file).
let device = Default::default();
// Fresh randomly-initialized network on the training backend.
let mut model: Net<TrainingBackend> = Net::init(&device);
// Plain SGD optimizer built from default config; the learning rate is passed
// per-step below rather than stored in the config.
let mut optim = SgdConfig::new().init();
// NOTE(review): inputs/targets are all-zero placeholder tensors sized by the
// INPUT/OUTPUT consts — training against zeros only; replace with real
// simulation samples before this loop can learn anything meaningful.
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
const LR: f64 = 0.5;
const EPOCHS: usize = 100;
for epoch in 0..EPOCHS {
    // Forward pass; inputs are cloned because forward consumes the tensor.
    let predictions = model.forward(inputs.clone());
    // Mean squared error between predictions and (placeholder) targets.
    let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
    // NOTE(review): with EPOCHS = 100, `epoch % 2000 == 0` only matches
    // epoch 0, so this effectively logs just the first and last epochs —
    // the 2000 interval looks left over from a longer run; confirm intent.
    if epoch % 2000 == 0 || epoch == EPOCHS - 1 {
        // .clone().into_scalar() extracts the f32 value from a 1-element tensor.
        println!(
            " epoch {:>5} | loss = {:.8}",
            epoch,
            loss.clone().into_scalar()
        );
    }
    // Backprop: compute gradients, gather them keyed by model parameters,
    // then take one SGD step (step consumes and returns the model).
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    model = optim.step(LR, model, grads);
}
}
fn inference() {