Files
strafe-ai/src/net.rs
2026-03-27 16:04:03 -07:00

61 lines
1.6 KiB
Rust

use burn::backend::Autodiff;
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
use burn::prelude::*;
/// Backend used for inference: burn's CUDA backend parameterised with `f32`.
pub type InferenceBackend = burn::backend::Cuda<f32>;
/// Backend used for training: [`InferenceBackend`] wrapped in burn's
/// [`Autodiff`] decorator.
pub type TrainingBackend = Autodiff<InferenceBackend>;
/// Width (`x`) and height (`y`) of the 2D grid that feeds the network; every
/// cell contributes one input value (see [`INPUT`]).
///
/// NOTE(review): presumably a downsampled frame/depth image — confirm against
/// the code that builds the input tensor.
pub const SIZE: glam::UVec2 = glam::uvec2(64, 36);
/// Number of history entries appended to the grid values; each entry
/// contributes three input values (see [`INPUT`]).
///
/// NOTE(review): presumably xyz positions — confirm against the caller.
pub const POSITION_HISTORY: usize = 10;
/// Total width of the network input: three values per history entry plus one
/// value per grid cell (currently `30 + 2304 = 2334`).
pub const INPUT: usize = 3 * POSITION_HISTORY + (SIZE.x * SIZE.y) as usize;
/// Widths of the successive hidden layers: `INPUT` divided by 8, 32 and 128,
/// rounded down (currently `291`, `72`, `18`).
pub const HIDDEN: [usize; 3] = [INPUT / 8, INPUT / 32, INPUT / 128];
/// Number of values produced by the output layer of [`Net`], one per control
/// listed below.
///
/// NOTE(review): the list presumably mirrors the index order of the output
/// vector — confirm against the code that decodes predictions.
///
/// - `MoveForward`
/// - `MoveLeft`
/// - `MoveBack`
/// - `MoveRight`
/// - `Jump`
/// - `mouse_dx`
/// - `mouse_dy`
pub const OUTPUT: usize = 7;
/// Fully connected network mapping `INPUT` values to `OUTPUT` values through
/// the layer widths listed in `HIDDEN`.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
/// First linear layer: `INPUT` -> `HIDDEN[0]`.
input: Linear<B>,
/// Dropout applied to the output of `input` (probability `0.1`, set in `init`).
dropout: Dropout,
/// Linear layers between consecutive `HIDDEN` widths; one fewer than
/// `HIDDEN.len()` because `input` already produces the first width.
hidden: [Linear<B>; HIDDEN.len() - 1],
/// Final linear layer: last `HIDDEN` width -> `OUTPUT`. No activation follows it.
output: Linear<B>,
/// ReLU applied after `input` (post-dropout) and after every `hidden` layer.
activation: Relu,
}
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
let dropout = DropoutConfig::new(0.1).init();
Self {
input,
dropout,
hidden,
output,
activation: Relu::new(),
}
}
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.input.forward(input);
let x = self.dropout.forward(x);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}