forked from StrafesNET/strafe-ai
training things
This commit is contained in:
33
src/main.rs
33
src/main.rs
@@ -12,6 +12,7 @@ const INPUT: usize = 2;
|
||||
const HIDDEN: usize = 64;
|
||||
const OUTPUT: usize = 8;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Net<B: Backend> {
|
||||
input: Linear<B>,
|
||||
hidden: [Linear<B>; 64],
|
||||
@@ -47,6 +48,38 @@ fn training() {
|
||||
// setup player
|
||||
// setup simulation
|
||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||
|
||||
let mut optim = SgdConfig::new().init();
|
||||
|
||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
||||
|
||||
const LR: f64 = 0.5;
|
||||
const EPOCHS: usize = 100;
|
||||
|
||||
for epoch in 0..EPOCHS {
|
||||
let predictions = model.forward(inputs.clone());
|
||||
|
||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||
|
||||
if epoch % 2000 == 0 || epoch == EPOCHS - 1 {
|
||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||
println!(
|
||||
" epoch {:>5} | loss = {:.8}",
|
||||
epoch,
|
||||
loss.clone().into_scalar()
|
||||
);
|
||||
}
|
||||
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
|
||||
model = optim.step(LR, model, grads);
|
||||
}
|
||||
}
|
||||
|
||||
fn inference() {
|
||||
|
||||
Reference in New Issue
Block a user