19 Commits
png ... main

Author SHA1 Message Date
48f9657d0f don't hardcode map and bot 2026-03-27 16:46:02 -07:00
e38c0a92b4 remove chrono dep 2026-03-27 16:28:30 -07:00
148471dce1 simulate: add output_file argument 2026-03-27 16:28:30 -07:00
7bf439395b write model name based on num params 2026-03-27 16:16:26 -07:00
03f5eb5c13 tweak model 2026-03-27 16:04:03 -07:00
1e7bb6c4ce format 2026-03-27 15:57:32 -07:00
d8b0f9abbb rename GraphicsState to InputGenerator 2026-03-27 15:57:17 -07:00
fb8c6e2492 training options 2026-03-27 15:56:15 -07:00
18cad85b62 add cli args 2026-03-27 15:56:15 -07:00
b195a7eb95 split code into modules 2026-03-27 15:46:51 -07:00
4208090da0 add clap dep 2026-03-27 15:32:17 -07:00
e31b148f41 simulator 2026-03-27 15:29:32 -07:00
a05113baa5 feed position history into model inputs 2026-03-27 15:28:53 -07:00
9ad8a70ad0 hardcode depth "normalization" 2026-03-27 15:13:29 -07:00
e890623f2e add dropout to input 2026-03-27 15:00:30 -07:00
7d55e872e7 save model with current date 2026-03-27 11:56:45 -07:00
1e1cbeb180 graphics state 2026-03-27 11:56:45 -07:00
e19c46d851 save best model 2026-03-27 11:25:34 -07:00
59bb8eee12 implement training 2026-03-27 11:13:23 -07:00
8 changed files with 896 additions and 202 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
/files
/target

230
Cargo.lock generated
View File

@@ -82,6 +82,56 @@ dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
[[package]]
name = "anstyle-parse"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.61.2",
]
[[package]]
name = "anyhow"
version = "1.0.102"
@@ -1038,9 +1088,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.57"
version = "1.2.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1"
dependencies = [
"find-msvc-tools",
"jobserver",
@@ -1101,6 +1151,46 @@ dependencies = [
"libloading",
]
[[package]]
name = "clap"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "clap_lex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
[[package]]
name = "codespan-reporting"
version = "0.12.0"
@@ -1109,7 +1199,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.1.14",
"unicode-width 0.2.0",
]
[[package]]
@@ -1120,7 +1210,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.1.14",
"unicode-width 0.2.0",
]
[[package]]
@@ -1129,13 +1219,19 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "colorchoice"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
[[package]]
name = "colored"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2028,7 +2124,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2232,7 +2328,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2919,7 +3015,7 @@ dependencies = [
"log",
"presser",
"thiserror 2.0.18",
"windows 0.58.0",
"windows 0.62.2",
]
[[package]]
@@ -3396,6 +3492,12 @@ dependencies = [
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.13.0"
@@ -3789,9 +3891,9 @@ dependencies = [
[[package]]
name = "mio"
version = "1.1.1"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc"
checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1"
dependencies = [
"libc",
"log",
@@ -3843,9 +3945,9 @@ dependencies = [
[[package]]
name = "naga"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85b4372fed0bd362d646d01b6926df0e837859ccc522fed720c395e0460f29c8"
checksum = "aa2630921705b9b01dcdd0b6864b9562ca3c1951eecd0f0c4f5f04f61e412647"
dependencies = [
"arrayvec",
"bit-set 0.9.1",
@@ -3968,7 +4070,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4201,6 +4303,12 @@ version = "1.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
[[package]]
name = "once_cell_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -5054,7 +5162,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.12.1",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -5346,9 +5454,9 @@ dependencies = [
[[package]]
name = "simd-adler32"
version = "0.3.8"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214"
[[package]]
name = "simd_helpers"
@@ -5397,7 +5505,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -5451,6 +5559,8 @@ name = "strafe-ai"
version = "0.1.0"
dependencies = [
"burn",
"clap",
"glam",
"pollster",
"strafesnet_common",
"strafesnet_graphics",
@@ -5458,7 +5568,7 @@ dependencies = [
"strafesnet_roblox_bot_file",
"strafesnet_roblox_bot_player",
"strafesnet_snf",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5478,16 +5588,16 @@ dependencies = [
[[package]]
name = "strafesnet_graphics"
version = "0.0.10"
version = "0.0.11-depth2"
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
checksum = "5080cb31a6cf898daab6c960801828ce9834dba8e932dea6b02823651ea53c33"
checksum = "829804ab9c167365e576de8ebd8a245ad979cb24558b086e693e840697d7956c"
dependencies = [
"bytemuck",
"ddsfile",
"glam",
"id",
"strafesnet_common",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5515,16 +5625,16 @@ dependencies = [
[[package]]
name = "strafesnet_roblox_bot_player"
version = "0.6.1"
version = "0.6.2-depth2"
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
checksum = "0669779b58836ac36b0166f5a3f326ee46ce25b4d14b7fd6f75bf273e806c1bf"
checksum = "f39e7dfc0cb23e482089dc7eac235ad4b274ccfdb8df7617889a90e64a1e247a"
dependencies = [
"glam",
"strafesnet_common",
"strafesnet_graphics",
"strafesnet_roblox_bot_file",
"thiserror 2.0.18",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5729,7 +5839,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix 1.1.4",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -6301,9 +6411,9 @@ dependencies = [
[[package]]
name = "unicode-segmentation"
version = "1.13.1"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da36089a805484bcccfffe0739803392c8298778a2d2f09febf76fac5ad9025b"
checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c"
[[package]]
name = "unicode-truncate"
@@ -6401,10 +6511,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.22.0"
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9"
dependencies = [
"getrandom 0.4.2",
"js-sys",
@@ -6681,9 +6797,9 @@ dependencies = [
[[package]]
name = "wgpu"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78f9f386699b1fb8b8a05bfe82169b24d151f05702d2905a0bf93bc454fcc825"
checksum = "72c239a9a747bbd379590985bac952c2e53cb19873f7072b3370c6a6a8e06837"
dependencies = [
"arrayvec",
"bitflags",
@@ -6694,7 +6810,7 @@ dependencies = [
"hashbrown 0.16.1",
"js-sys",
"log",
"naga 29.0.0",
"naga 29.0.1",
"parking_lot",
"portable-atomic",
"profiling",
@@ -6704,9 +6820,9 @@ dependencies = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"wgpu-core 29.0.0",
"wgpu-hal 29.0.0",
"wgpu-types 29.0.0",
"wgpu-core 29.0.1",
"wgpu-hal 29.0.1",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6742,9 +6858,9 @@ dependencies = [
[[package]]
name = "wgpu-core"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c34181b0acb8f98168f78f8e57ec66f57df5522b39143dbe5f2f45d7ca927c"
checksum = "1e80ac6cf1895df6342f87d975162108f9d98772a0d74bc404ab7304ac29469e"
dependencies = [
"arrayvec",
"bit-set 0.9.1",
@@ -6756,7 +6872,7 @@ dependencies = [
"hashbrown 0.16.1",
"indexmap",
"log",
"naga 29.0.0",
"naga 29.0.1",
"once_cell",
"parking_lot",
"portable-atomic",
@@ -6768,9 +6884,9 @@ dependencies = [
"wgpu-core-deps-apple 29.0.0",
"wgpu-core-deps-emscripten 29.0.0",
"wgpu-core-deps-windows-linux-android 29.0.0",
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
"wgpu-naga-bridge",
"wgpu-types 29.0.0",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6788,7 +6904,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43acd053312501689cd92a01a9638d37f3e41a5fd9534875efa8917ee2d11ac0"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6806,7 +6922,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef043bf135cc68b6f667c55ff4e345ce2b5924d75bad36a47921b0287ca4b24a"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6824,7 +6940,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "725d5c006a8c02967b6d93ef04f6537ec4593313e330cfe86d9d3f946eb90f28"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6877,9 +6993,9 @@ dependencies = [
[[package]]
name = "wgpu-hal"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "058b6047337cf323a4f092486443a9337f3d81325347e5d77deed7e563aeaedc"
checksum = "89a47aef47636562f3937285af4c44b4b5b404b46577471411cc5313a921da7e"
dependencies = [
"android_system_properties",
"arrayvec",
@@ -6900,7 +7016,7 @@ dependencies = [
"libc",
"libloading",
"log",
"naga 29.0.0",
"naga 29.0.1",
"ndk-sys",
"objc2",
"objc2-core-foundation",
@@ -6923,19 +7039,19 @@ dependencies = [
"wayland-sys",
"web-sys",
"wgpu-naga-bridge",
"wgpu-types 29.0.0",
"wgpu-types 29.0.1",
"windows 0.62.2",
"windows-core 0.62.2",
]
[[package]]
name = "wgpu-naga-bridge"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0b8e1e505095f24cb4a578f04b1421d456257dca7fac114d9d9dd3d978c34b8"
checksum = "7b4684f4410da0cf95a4cb63bb5edaac022461dedb6adf0b64d0d9b5f6890d51"
dependencies = [
"naga 29.0.0",
"wgpu-types 29.0.0",
"naga 29.0.1",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6954,9 +7070,9 @@ dependencies = [
[[package]]
name = "wgpu-types"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d15ece45db77dd5451f11c0ce898334317ce8502d304a20454b531fdc0652fae"
checksum = "ec2675540fb1a5cfa5ef122d3d5f390e2c75711a0b946410f2d6ac3a0f77d1f6"
dependencies = [
"bitflags",
"bytemuck",
@@ -6988,7 +7104,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -7755,9 +7871,9 @@ dependencies = [
[[package]]
name = "zune-jpeg"
version = "0.5.14"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6"
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
dependencies = [
"zune-core",
]

View File

@@ -5,12 +5,14 @@ edition = "2024"
[dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
clap = { version = "4.6.0", features = ["derive"] }
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0"
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
strafesnet_graphics = { version = "0.0.10", registry = "strafesnet" }
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "0.6.1", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
pollster = "0.4.0"

244
src/inference.rs Normal file
View File

@@ -0,0 +1,244 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Simulate(SimulateSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Simulate(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct SimulateSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
model_file: std::path::PathBuf,
#[arg(long)]
output_file: Option<std::path::PathBuf>,
#[arg(long)]
map_file: std::path::PathBuf,
}
impl SimulateSubcommand {
fn run(self) {
let output_file = self.output_file.unwrap_or_else(|| {
let mut file_name = self
.model_file
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_owned();
file_name.push_str("_replay");
let mut path = self.model_file.clone();
path.set_file_name(file_name);
path.set_extension("snfb");
path
});
inference(
self.gpu_id.unwrap_or_default(),
self.model_file,
output_file,
self.map_file,
);
}
}
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{INPUT, InferenceBackend, Net};
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, MiscInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
pub struct Recording {
instructions: Vec<TimedInstruction<PhysicsInputInstruction, PhysicsTime>>,
}
struct FrameState {
trajectory: strafesnet_physics::physics::Trajectory,
camera: strafesnet_physics::physics::PhysicsCamera,
}
impl FrameState {
fn pos(&self, time: PhysicsTime) -> glam::Vec3 {
self.trajectory
.extrapolated_position(time)
.map(Into::<f32>::into)
.to_array()
.into()
}
fn angles(&self) -> glam::Vec2 {
self.camera.simulate_move_angles(glam::IVec2::ZERO)
}
}
struct Session {
geometry_shared: PhysicsData,
simulation: PhysicsState,
recording: Recording,
}
impl Session {
fn get_frame_state(&self) -> FrameState {
FrameState {
trajectory: self.simulation.camera_trajectory(&self.geometry_shared),
camera: self.simulation.camera(),
}
}
fn run(&mut self, time: PhysicsTime, instruction: PhysicsInputInstruction) {
let instruction = TimedInstruction { time, instruction };
self.recording.instructions.push(instruction.clone());
PhysicsContext::run_input_instruction(
&mut self.simulation,
&self.geometry_shared,
instruction,
);
}
}
fn inference(
gpu_id: usize,
model_file: std::path::PathBuf,
output_file: std::path::PathBuf,
map_file: std::path::PathBuf,
) {
// pick device
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
// load model
let mut model: Net<InferenceBackend> = Net::init(&device);
model = model
.load_file(
model_file,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device,
)
.unwrap();
// load map
let map_file = std::fs::read(map_file).unwrap();
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let modes = map.modes.clone().denormalize();
let mode = modes
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
.unwrap();
let start_zone = map.models.get(mode.get_start().get() as usize).unwrap();
let start_offset = glam::Vec3::from_array(
start_zone
.transform
.translation
.map(|f| f.into())
.to_array(),
);
// setup graphics
let mut g = InputGenerator::new(&map);
// setup simulation
let mut session = Session {
geometry_shared: PhysicsData::new(&map),
simulation: PhysicsState::default(),
recording: Recording {
instructions: Vec::new(),
},
};
let mut time = PhysicsTime::ZERO;
// reset to start zone
session.run(time, PhysicsInputInstruction::Mode(ModeInstruction::Reset));
// session.run(
// time,
// PhysicsInputInstruction::Misc(MiscInstruction::SetSensitivity(?)),
// );
session.run(
time,
PhysicsInputInstruction::Mode(ModeInstruction::Restart(
strafesnet_common::gameplay_modes::ModeId::MAIN,
)),
);
// TEMP: turn mouse left
let mut mouse_pos = glam::ivec2(-5300, 0);
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
let mut input_floats = Vec::new();
// setup agent-simulation feedback loop
for _ in 0..20 * 100 {
// generate inputs
let frame_state = session.get_frame_state();
g.generate_inputs(
frame_state.pos(time) - start_offset,
frame_state.angles(),
&mut input_floats,
);
// inference
let inputs = Tensor::from_data(
TensorData::new(input_floats.clone(), Shape::new([1, INPUT])),
&device,
);
let outputs = model.forward(inputs).into_data().into_vec::<f32>().unwrap();
let &[
move_forward,
move_left,
move_back,
move_right,
jump,
mouse_dx,
mouse_dy,
] = outputs.as_slice()
else {
panic!()
};
macro_rules! set_control {
($control:ident,$output:expr) => {
session.run(
time,
PhysicsInputInstruction::SetControl(SetControlInstruction::$control(
0.5 < $output,
)),
);
};
}
set_control!(SetMoveForward, move_forward);
set_control!(SetMoveLeft, move_left);
set_control!(SetMoveBack, move_back);
set_control!(SetMoveRight, move_right);
set_control!(SetJump, jump);
mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
let next_time = time + STEP;
session.run(
time,
PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(MouseState {
pos: mouse_pos,
time: next_time,
})),
);
time = next_time;
// clear
input_floats.clear();
}
let file = std::fs::File::create(output_file).unwrap();
strafesnet_snf::bot::write_bot(
std::io::BufWriter::new(file),
strafesnet_physics::VERSION.get(),
core::mem::take(&mut session.recording.instructions),
)
.unwrap();
}

166
src/inputs.rs Normal file
View File

@@ -0,0 +1,166 @@
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
use strafesnet_graphics::setup;
use crate::net::{POSITION_HISTORY, SIZE};
// bytes_per_row needs to be a multiple of 256.
const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256);
pub struct InputGenerator {
device: wgpu::Device,
queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
position_history: Vec<glam::Vec3>,
}
impl InputGenerator {
pub fn new(map: &strafesnet_common::map::CompleteMap) -> Self {
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device, &queue, SIZE, FORMAT, LIMITS,
);
graphics.change_map(&device, &queue, map).unwrap();
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let position_history = Vec::with_capacity(POSITION_HISTORY);
Self {
device,
queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
position_history,
}
}
pub fn generate_inputs(&mut self, pos: glam::Vec3, angles: glam::Vec2, inputs: &mut Vec<f32>) {
// write position history to model inputs
if !self.position_history.is_empty() {
let camera = strafesnet_graphics::graphics::view_inv(pos, angles).inverse();
for &pos in self.position_history.iter().rev() {
let relative_pos = camera.transform_vector3(pos);
inputs.extend_from_slice(&relative_pos.to_array());
}
}
// fill remaining history with zeroes
for _ in self.position_history.len()..POSITION_HISTORY {
inputs.extend_from_slice(&[0.0, 0.0, 0.0]);
}
// track position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history.push(pos);
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = pos;
}
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE.y),
},
},
wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
);
self.queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
// discombolulate stride
for y in 0..SIZE.y {
inputs.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
)
}
self.texture_data.clear();
}
}

View File

@@ -1,147 +1,30 @@
use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig, Relu, Sigmoid};
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
use burn::prelude::*;
use clap::{Parser, Subcommand};
type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
mod inference;
mod inputs;
mod net;
mod training;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
use strafesnet_common::session::Time as SessionTime;
use strafesnet_graphics::setup;
const INPUT: usize = 2;
const HIDDEN: usize = 64;
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
const OUTPUT: usize = 7;
#[derive(Module, Debug)]
struct Net<B: Backend> {
input: Linear<B>,
hidden: [Linear<B>; 64],
output: Linear<B>,
activation: Relu,
sigmoid: Sigmoid,
}
impl<B: Backend> Net<B> {
fn init(device: &B::Device) -> Self {
Self {
input: LinearConfig::new(INPUT, HIDDEN).init(device),
hidden: core::array::from_fn(|_| LinearConfig::new(HIDDEN, HIDDEN).init(device)),
output: LinearConfig::new(HIDDEN, OUTPUT).init(device),
activation: Relu::new(),
sigmoid: Sigmoid::new(),
}
}
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.input.forward(input);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
let x = self.output.forward(x);
self.sigmoid.forward(x)
}
#[derive(Parser)]
#[command(author,version,about,long_about=None)]
#[command(propagate_version = true)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
fn training() {
// load map
// load replay
// setup player
const SIZE_X: usize = 64;
const SIZE_Y: usize = 36;
let map_file = include_bytes!("../bhop_marble_5692093612.snfm");
let bot_file = include_bytes!("../bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
// read files
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let mut playback_head =
strafesnet_roblox_bot_player::head::PlaybackHead::new(&bot, SessionTime::ZERO);
// setup graphics
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
let graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device,
&queue,
[SIZE_X as u32, SIZE_Y as u32].into(),
FORMAT,
LIMITS,
);
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
let device = Default::default();
let mut model: Net<TrainingBackend> = Net::init(&device);
let mut optim = SgdConfig::new().init();
let inputs = Tensor::from_floats([0.0f32; INPUT], &device);
let targets = Tensor::from_floats([0.0f32; OUTPUT], &device);
const LEARNING_RATE: f64 = 0.5;
const EPOCHS: usize = 100;
for epoch in 0..EPOCHS {
let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(
" epoch {:>5} | loss = {:.8}",
epoch,
loss.clone().into_scalar()
);
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(LEARNING_RATE, model, grads);
}
#[derive(Subcommand)]
enum Commands {
#[command(flatten)]
Roblox(inference::Commands),
#[command(flatten)]
Source(training::Commands),
}
fn inference() {
// load map
// setup simulation
// setup agent-simulation feedback loop
// go!
fn main() {
let cli = Cli::parse();
match cli.command {
Commands::Roblox(commands) => commands.run(),
Commands::Source(commands) => commands.run(),
}
}
fn main() {}

60
src/net.rs Normal file
View File

@@ -0,0 +1,60 @@
use burn::backend::Autodiff;
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
use burn::prelude::*;
pub type InferenceBackend = burn::backend::Cuda<f32>;
pub type TrainingBackend = Autodiff<InferenceBackend>;
pub const SIZE: glam::UVec2 = glam::uvec2(64, 36);
pub const POSITION_HISTORY: usize = 10;
pub const INPUT: usize = (SIZE.x * SIZE.y) as usize + POSITION_HISTORY * 3;
pub const HIDDEN: [usize; 3] = [INPUT >> 3, INPUT >> 5, INPUT >> 7];
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
pub const OUTPUT: usize = 7;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
input: Linear<B>,
dropout: Dropout,
hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>,
activation: Relu,
}
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
let dropout = DropoutConfig::new(0.1).init();
Self {
input,
dropout,
hidden,
output,
activation: Relu::new(),
}
}
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.input.forward(input);
let x = self.dropout.forward(x);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}

222
src/training.rs Normal file
View File

@@ -0,0 +1,222 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Train(TrainSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Train(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct TrainSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
epochs: Option<usize>,
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long)]
map_file: std::path::PathBuf,
#[arg(long)]
bot_file: std::path::PathBuf,
}
impl TrainSubcommand {
fn run(self) {
training(
self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001),
self.map_file,
self.bot_file,
);
}
}
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{INPUT, Net, OUTPUT, TrainingBackend};
use strafesnet_roblox_bot_file::v0;
fn training(
gpu_id: usize,
epochs: usize,
learning_rate: f64,
map_file: std::path::PathBuf,
bot_file: std::path::PathBuf,
) {
// read files
let map_file = std::fs::read(map_file).unwrap();
let bot_file = std::fs::read(bot_file).unwrap();
// load map
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
// load replay
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// set up graphics
let mut g = InputGenerator::new(&map);
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = INPUT * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
fn vec3(v: v0::Vector3) -> glam::Vec3 {
glam::vec3(v.x, v.y, v.z)
}
fn angles(a: v0::Vector3) -> glam::Vec2 {
glam::vec2(a.y, a.x)
}
let pos = vec3(output_event.event.position) - world_offset;
let angles = angles(output_event.event.angles);
g.generate_inputs(pos, angles, &mut inputs);
}
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
let num_params = model.num_params();
println!("Training model ({} parameters)", num_params);
let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_data(
TensorData::new(inputs, Shape::new([training_samples, INPUT])),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..epochs {
let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
let loss_scalar = loss.clone().into_scalar();
if epoch == 0 {
// kinda a fake print, but that's what is happening after this point
println!("Compiling optimized GPU kernels...");
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
// get the best model
if loss_scalar < best_loss {
best_loss = loss_scalar;
best_model = model.clone();
}
model = optim.step(learning_rate, model, grads);
if epoch % (epochs >> 4) == 0 || epoch == epochs - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}
}
let date_string = format!("{}_{}.model", num_params, best_loss);
best_model
.save_file(
date_string,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
)
.unwrap();
}