forked from StrafesNET/strafe-ai
Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
9248bdc362
|
|||
|
696a148786
|
|||
|
9f278f3269
|
|||
|
ee3e173256
|
|||
|
d908febda0
|
|||
|
4dd494aff1
|
|||
|
aa9d7eaace
|
|||
|
f406f126ee
|
|||
|
ec73f62f89
|
|||
|
25bee24e4c
|
|||
|
93c01910cb
|
|||
|
293daf3ab2
|
|||
|
8283759d47
|
|||
|
d31868f033
|
|||
|
da20ad0464
|
|||
|
8ac19cb5fe
|
|||
|
6a4c222c90
|
|||
|
989bc37dc4
|
|||
|
9e441c1d95
|
|||
|
be265fae24
|
|||
|
07c9535cb1
|
|||
|
4c551ed1a8
|
|||
|
46a4f25e55
|
|||
|
5724cabba1
|
|||
|
02cfd6c052
|
|||
|
4d3811a590
|
|||
|
4f51665372
|
|||
|
669b6f7614
|
|||
|
dd26c463f0
|
|||
|
6775f647cc
|
|||
|
1cb5c01bb0
|
|||
|
ba406eb2dd
|
|||
|
1f0475ef46
|
|||
|
dc72a8d5e3
|
|||
|
3ce6ad84a3
|
|||
|
48eefba747
|
|||
|
842949b6d6
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
|
/files
|
||||||
/target
|
/target
|
||||||
|
|||||||
25
Cargo.lock
generated
25
Cargo.lock
generated
@@ -1109,7 +1109,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"unicode-width 0.1.14",
|
"unicode-width 0.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1120,7 +1120,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"unicode-width 0.1.14",
|
"unicode-width 0.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1135,7 +1135,7 @@ version = "3.1.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
|
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2232,7 +2232,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2919,7 +2919,7 @@ dependencies = [
|
|||||||
"log",
|
"log",
|
||||||
"presser",
|
"presser",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"windows 0.58.0",
|
"windows 0.61.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5054,7 +5054,7 @@ dependencies = [
|
|||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys 0.12.1",
|
"linux-raw-sys 0.12.1",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5451,6 +5451,7 @@ name = "strafe-ai"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"burn",
|
"burn",
|
||||||
|
"glam",
|
||||||
"pollster",
|
"pollster",
|
||||||
"strafesnet_common",
|
"strafesnet_common",
|
||||||
"strafesnet_graphics",
|
"strafesnet_graphics",
|
||||||
@@ -5478,9 +5479,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strafesnet_graphics"
|
name = "strafesnet_graphics"
|
||||||
version = "0.0.10"
|
version = "0.0.11-depth2"
|
||||||
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
||||||
checksum = "5080cb31a6cf898daab6c960801828ce9834dba8e932dea6b02823651ea53c33"
|
checksum = "829804ab9c167365e576de8ebd8a245ad979cb24558b086e693e840697d7956c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"ddsfile",
|
"ddsfile",
|
||||||
@@ -5515,9 +5516,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strafesnet_roblox_bot_player"
|
name = "strafesnet_roblox_bot_player"
|
||||||
version = "0.6.1"
|
version = "0.6.2-depth2"
|
||||||
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
||||||
checksum = "0669779b58836ac36b0166f5a3f326ee46ce25b4d14b7fd6f75bf273e806c1bf"
|
checksum = "f39e7dfc0cb23e482089dc7eac235ad4b274ccfdb8df7617889a90e64a1e247a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"glam",
|
"glam",
|
||||||
"strafesnet_common",
|
"strafesnet_common",
|
||||||
@@ -5729,7 +5730,7 @@ dependencies = [
|
|||||||
"getrandom 0.4.2",
|
"getrandom 0.4.2",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"rustix 1.1.4",
|
"rustix 1.1.4",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -6988,7 +6989,7 @@ version = "0.1.11"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
|
|||||||
wgpu = "29.0.0"
|
wgpu = "29.0.0"
|
||||||
|
|
||||||
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
|
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
|
||||||
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" }
|
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
|
||||||
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
|
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
|
||||||
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
|
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
|
||||||
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" }
|
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
|
||||||
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
|
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
|
||||||
pollster = "0.4.0"
|
pollster = "0.4.0"
|
||||||
|
glam = "0.32.1"
|
||||||
|
|||||||
299
src/main.rs
299
src/main.rs
@@ -1,19 +1,20 @@
|
|||||||
use burn::backend::Autodiff;
|
use burn::backend::Autodiff;
|
||||||
use burn::module::AutodiffModule;
|
|
||||||
use burn::nn::loss::{MseLoss, Reduction};
|
use burn::nn::loss::{MseLoss, Reduction};
|
||||||
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
|
use burn::nn::{Linear, LinearConfig, Relu};
|
||||||
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
|
|
||||||
type InferenceBackend = burn::backend::Cuda<f32>;
|
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||||
type TrainingBackend = Autodiff<InferenceBackend>;
|
type TrainingBackend = Autodiff<InferenceBackend>;
|
||||||
|
|
||||||
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||||
use strafesnet_common::session::Time as SessionTime;
|
|
||||||
use strafesnet_graphics::setup;
|
use strafesnet_graphics::setup;
|
||||||
|
use strafesnet_roblox_bot_file::v0;
|
||||||
|
|
||||||
const INPUT: usize = 2;
|
const SIZE_X: usize = 64;
|
||||||
const HIDDEN: usize = 64;
|
const SIZE_Y: usize = 36;
|
||||||
|
const INPUT: usize = SIZE_X * SIZE_Y;
|
||||||
|
const HIDDEN: [usize; 2] = [INPUT >> 3, INPUT >> 7];
|
||||||
// MoveForward
|
// MoveForward
|
||||||
// MoveLeft
|
// MoveLeft
|
||||||
// MoveBack
|
// MoveBack
|
||||||
@@ -26,19 +27,27 @@ const OUTPUT: usize = 7;
|
|||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
struct Net<B: Backend> {
|
struct Net<B: Backend> {
|
||||||
input: Linear<B>,
|
input: Linear<B>,
|
||||||
hidden: [Linear<B>; 64],
|
hidden: [Linear<B>; HIDDEN.len() - 1],
|
||||||
output: Linear<B>,
|
output: Linear<B>,
|
||||||
activation: Relu,
|
activation: Relu,
|
||||||
sigmoid: Sigmoid,
|
|
||||||
}
|
}
|
||||||
impl<B: Backend> Net<B> {
|
impl<B: Backend> Net<B> {
|
||||||
fn init(device: &B::Device) -> Self {
|
fn init(device: &B::Device) -> Self {
|
||||||
|
let mut it = HIDDEN.into_iter();
|
||||||
|
let mut last_size = it.next().unwrap();
|
||||||
|
let input = LinearConfig::new(INPUT, last_size).init(device);
|
||||||
|
let hidden = core::array::from_fn(|_| {
|
||||||
|
let size = it.next().unwrap();
|
||||||
|
let layer = LinearConfig::new(last_size, size).init(device);
|
||||||
|
last_size = size;
|
||||||
|
layer
|
||||||
|
});
|
||||||
|
let output = LinearConfig::new(last_size, OUTPUT).init(device);
|
||||||
Self {
|
Self {
|
||||||
input: LinearConfig::new(INPUT, HIDDEN).init(device),
|
input,
|
||||||
hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)),
|
hidden,
|
||||||
output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
|
output,
|
||||||
activation: Relu::new(),
|
activation: Relu::new(),
|
||||||
sigmoid: Sigmoid::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
@@ -48,20 +57,22 @@ impl<B: Backend> Net<B> {
|
|||||||
x = layer.forward(x);
|
x = layer.forward(x);
|
||||||
x = self.activation.forward(x);
|
x = self.activation.forward(x);
|
||||||
}
|
}
|
||||||
let x = self.output.forward(x);
|
self.output.forward(x)
|
||||||
self.sigmoid.forward(x)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn training() {
|
fn training() {
|
||||||
|
let gpu_id: usize = std::env::args()
|
||||||
|
.skip(1)
|
||||||
|
.next()
|
||||||
|
.map(|id| id.parse().unwrap())
|
||||||
|
.unwrap_or_default();
|
||||||
// load map
|
// load map
|
||||||
// load replay
|
// load replay
|
||||||
// setup player
|
// setup player
|
||||||
const SIZE_X: usize = 64;
|
|
||||||
const SIZE_Y: usize = 36;
|
|
||||||
|
|
||||||
let map_file = include_bytes!("../bhop_marble_5692093612.snfm");
|
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
|
||||||
let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
|
||||||
|
|
||||||
// read files
|
// read files
|
||||||
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
|
||||||
@@ -70,10 +81,9 @@ fn training() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let timelines =
|
let timelines =
|
||||||
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
|
||||||
|
|
||||||
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
|
||||||
let mut playback_head =
|
let world_offset = bot.world_offset();
|
||||||
strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO);
|
let timelines = bot.timelines();
|
||||||
|
|
||||||
// setup graphics
|
// setup graphics
|
||||||
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
|
||||||
@@ -93,47 +103,252 @@ fn training() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
|
||||||
let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
let size = [SIZE_X as u32, SIZE_Y as u32].into();
|
||||||
&device,
|
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
|
||||||
&queue,
|
&device, &queue, size, FORMAT, LIMITS,
|
||||||
[SIZE_X as u32, SIZE_Y as u32].into(),
|
|
||||||
FORMAT,
|
|
||||||
LIMITS,
|
|
||||||
);
|
);
|
||||||
|
graphics.change_map(&device, &queue, &map).unwrap();
|
||||||
|
|
||||||
// setup simulation
|
// setup simulation
|
||||||
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
|
||||||
|
|
||||||
let device = Default::default();
|
// set up textures
|
||||||
|
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
|
||||||
|
label: Some("RGB texture"),
|
||||||
|
format: FORMAT,
|
||||||
|
size: wgpu::Extent3d {
|
||||||
|
width: size.x,
|
||||||
|
height: size.y,
|
||||||
|
depth_or_array_layers: 1,
|
||||||
|
},
|
||||||
|
mip_level_count: 1,
|
||||||
|
sample_count: 1,
|
||||||
|
dimension: wgpu::TextureDimension::D2,
|
||||||
|
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
|
||||||
|
view_formats: &[],
|
||||||
|
});
|
||||||
|
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
|
||||||
|
label: Some("RGB texture view"),
|
||||||
|
aspect: wgpu::TextureAspect::All,
|
||||||
|
usage: Some(wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
// bytes_per_row needs to be a multiple of 256.
|
||||||
|
let stride_size = (size.x * size_of::<f32>() as u32).next_multiple_of(256);
|
||||||
|
let mut texture_data = Vec::<u8>::with_capacity((stride_size * size.y) as usize);
|
||||||
|
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("Output staging buffer"),
|
||||||
|
size: texture_data.capacity() as u64,
|
||||||
|
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
// training data
|
||||||
|
let training_samples = timelines.input_events.len() - 1;
|
||||||
|
|
||||||
|
let input_size = INPUT * size_of::<f32>();
|
||||||
|
let mut inputs = Vec::with_capacity(input_size * training_samples);
|
||||||
|
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
|
||||||
|
|
||||||
|
// generate all frames
|
||||||
|
println!("Generating {training_samples} frames of depth textures...");
|
||||||
|
let mut it = timelines.input_events.iter();
|
||||||
|
|
||||||
|
// grab mouse position from first frame, omitting one frame from the training data
|
||||||
|
let first = it.next().unwrap();
|
||||||
|
let mut last_mx = first.event.mouse_pos.x;
|
||||||
|
let mut last_my = first.event.mouse_pos.y;
|
||||||
|
|
||||||
|
for input_event in it {
|
||||||
|
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
|
||||||
|
let mouse_dy = input_event.event.mouse_pos.y - last_my;
|
||||||
|
last_mx = input_event.event.mouse_pos.x;
|
||||||
|
last_my = input_event.event.mouse_pos.y;
|
||||||
|
|
||||||
|
// set targets
|
||||||
|
targets.extend([
|
||||||
|
// MoveForward
|
||||||
|
input_event
|
||||||
|
.event
|
||||||
|
.game_controls
|
||||||
|
.contains(v0::GameControls::MoveForward) as i32 as f32,
|
||||||
|
// MoveLeft
|
||||||
|
input_event
|
||||||
|
.event
|
||||||
|
.game_controls
|
||||||
|
.contains(v0::GameControls::MoveLeft) as i32 as f32,
|
||||||
|
// MoveBack
|
||||||
|
input_event
|
||||||
|
.event
|
||||||
|
.game_controls
|
||||||
|
.contains(v0::GameControls::MoveBack) as i32 as f32,
|
||||||
|
// MoveRight
|
||||||
|
input_event
|
||||||
|
.event
|
||||||
|
.game_controls
|
||||||
|
.contains(v0::GameControls::MoveRight) as i32 as f32,
|
||||||
|
// Jump
|
||||||
|
input_event
|
||||||
|
.event
|
||||||
|
.game_controls
|
||||||
|
.contains(v0::GameControls::Jump) as i32 as f32,
|
||||||
|
mouse_dx,
|
||||||
|
mouse_dy,
|
||||||
|
]);
|
||||||
|
|
||||||
|
// find the closest output event to the input event time
|
||||||
|
let output_event_index = timelines
|
||||||
|
.output_events
|
||||||
|
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
|
||||||
|
|
||||||
|
let output_event = match output_event_index {
|
||||||
|
// found the exact same timestamp
|
||||||
|
Ok(output_event_index) => &timelines.output_events[output_event_index],
|
||||||
|
// found first index greater than the time.
|
||||||
|
// check this index and the one before and return the closest one
|
||||||
|
Err(insert_index) => timelines
|
||||||
|
.output_events
|
||||||
|
.get(insert_index)
|
||||||
|
.into_iter()
|
||||||
|
.chain(
|
||||||
|
insert_index
|
||||||
|
.checked_sub(1)
|
||||||
|
.and_then(|index| timelines.output_events.get(index)),
|
||||||
|
)
|
||||||
|
.min_by(|&e0, &e1| {
|
||||||
|
(e0.time - input_event.time)
|
||||||
|
.abs()
|
||||||
|
.partial_cmp(&(e1.time - input_event.time).abs())
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
fn vec3(v: v0::Vector3) -> glam::Vec3 {
|
||||||
|
glam::vec3(v.x, v.y, v.z)
|
||||||
|
}
|
||||||
|
fn angles(a: v0::Vector3) -> glam::Vec2 {
|
||||||
|
glam::vec2(a.y, a.x)
|
||||||
|
}
|
||||||
|
|
||||||
|
let pos = vec3(output_event.event.position) - world_offset;
|
||||||
|
let angles = angles(output_event.event.angles);
|
||||||
|
|
||||||
|
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||||
|
label: Some("wgpu encoder"),
|
||||||
|
});
|
||||||
|
|
||||||
|
// render!
|
||||||
|
graphics.encode_commands(&mut encoder, &graphics_texture_view, pos, angles);
|
||||||
|
|
||||||
|
// copy the depth texture into ram
|
||||||
|
encoder.copy_texture_to_buffer(
|
||||||
|
wgpu::TexelCopyTextureInfo {
|
||||||
|
texture: graphics.depth_texture(),
|
||||||
|
mip_level: 0,
|
||||||
|
origin: wgpu::Origin3d::ZERO,
|
||||||
|
aspect: wgpu::TextureAspect::All,
|
||||||
|
},
|
||||||
|
wgpu::TexelCopyBufferInfo {
|
||||||
|
buffer: &output_staging_buffer,
|
||||||
|
layout: wgpu::TexelCopyBufferLayout {
|
||||||
|
offset: 0,
|
||||||
|
// This needs to be a multiple of 256.
|
||||||
|
bytes_per_row: Some(stride_size as u32),
|
||||||
|
rows_per_image: Some(size.y),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wgpu::Extent3d {
|
||||||
|
width: size.x,
|
||||||
|
height: size.y,
|
||||||
|
depth_or_array_layers: 1,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
queue.submit([encoder.finish()]);
|
||||||
|
|
||||||
|
// map buffer
|
||||||
|
let buffer_slice = output_staging_buffer.slice(..);
|
||||||
|
let (sender, receiver) = std::sync::mpsc::channel();
|
||||||
|
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
|
||||||
|
device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
|
||||||
|
receiver.recv().unwrap().unwrap();
|
||||||
|
|
||||||
|
// copy texture inside a scope so the mapped view gets dropped
|
||||||
|
{
|
||||||
|
let view = buffer_slice.get_mapped_range();
|
||||||
|
texture_data.extend_from_slice(&view[..]);
|
||||||
|
}
|
||||||
|
output_staging_buffer.unmap();
|
||||||
|
|
||||||
|
// discombolulate stride
|
||||||
|
for y in 0..size.y {
|
||||||
|
inputs.extend(
|
||||||
|
texture_data[(stride_size * y) as usize
|
||||||
|
..(stride_size * y + size.x * size_of::<f32>() as u32) as usize]
|
||||||
|
.chunks_exact(4)
|
||||||
|
.map(|b| f32::from_le_bytes(b.try_into().unwrap())),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
texture_data.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize inputs
|
||||||
|
let global_min = *inputs
|
||||||
|
.iter()
|
||||||
|
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
let global_max = *inputs
|
||||||
|
.iter()
|
||||||
|
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
let global_range = global_max - global_min;
|
||||||
|
println!("Normalizing to range {global_min} - {global_max}");
|
||||||
|
inputs.iter_mut().for_each(|value| {
|
||||||
|
*value = 1.0 - (*value - global_min) / global_range;
|
||||||
|
});
|
||||||
|
|
||||||
|
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
|
||||||
|
|
||||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||||
|
println!("Training model ({} parameters)", model.num_params());
|
||||||
|
|
||||||
let mut optim = SgdConfig::new().init();
|
let mut optim = AdamConfig::new().init();
|
||||||
|
|
||||||
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
|
let inputs = Tensor::from_data(
|
||||||
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
|
TensorData::new(inputs, Shape::new([training_samples, INPUT])),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let targets = Tensor::from_data(
|
||||||
|
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
|
||||||
const LEARNING_RATE: f64 = 0.5;
|
const LEARNING_RATE: f64 = 0.001;
|
||||||
const EPOCHS: usize = 100;
|
const EPOCHS: usize = 100000;
|
||||||
|
|
||||||
for epoch in 0..EPOCHS {
|
for epoch in 0..EPOCHS {
|
||||||
let predictions = model.forward(inputs.clone());
|
let predictions = model.forward(inputs.clone());
|
||||||
|
|
||||||
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
|
||||||
|
|
||||||
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
let loss_scalar = loss.clone().into_scalar();
|
||||||
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
|
||||||
println!(
|
if epoch == 0 {
|
||||||
" epoch {:>5} | loss = {:.8}",
|
// kinda a fake print, but that's what is happening after this point
|
||||||
epoch,
|
println!("Compiling optimized GPU kernels...");
|
||||||
loss.clone().into_scalar()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let grads = loss.backward();
|
let grads = loss.backward();
|
||||||
let grads = GradientsParams::from_grads(grads, &model);
|
let grads = GradientsParams::from_grads(grads, &model);
|
||||||
|
|
||||||
model = optim.step(LEARNING_RATE, model, grads);
|
model = optim.step(LEARNING_RATE, model, grads);
|
||||||
|
|
||||||
|
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
|
||||||
|
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
|
||||||
|
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,4 +359,6 @@ fn inference() {
|
|||||||
// go!
|
// go!
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {}
|
fn main() {
|
||||||
|
training();
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user