Compare commits
5 Commits
print-inpu
...
png
| Author | SHA1 | Date | |
|---|---|---|---|
|
7f9a76bd71
|
|||
|
757eb2a572
|
|||
|
0828f2ced0
|
|||
|
406763b8db
|
|||
|
d069271542
|
9
Cargo.lock
generated
9
Cargo.lock
generated
@@ -5451,6 +5451,7 @@ name = "strafe-ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"png",
|
||||
"pollster",
|
||||
"strafesnet_common",
|
||||
"strafesnet_graphics",
|
||||
@@ -5478,9 +5479,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "strafesnet_graphics"
|
||||
version = "0.0.11-depth2"
|
||||
version = "0.0.11-depth"
|
||||
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
||||
checksum = "829804ab9c167365e576de8ebd8a245ad979cb24558b086e693e840697d7956c"
|
||||
checksum = "16266ca7e57ce802b7abd24c6cd8f9b8d95752f7eaead27e42b431b9768d6135"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"ddsfile",
|
||||
@@ -5515,9 +5516,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "strafesnet_roblox_bot_player"
|
||||
version = "0.6.2-depth2"
|
||||
version = "0.6.2-depth"
|
||||
source = "sparse+https://git.itzana.me/api/packages/strafesnet/cargo/"
|
||||
checksum = "f39e7dfc0cb23e482089dc7eac235ad4b274ccfdb8df7617889a90e64a1e247a"
|
||||
checksum = "12d1aa21c174f23f7f7ede583292a8c82e4b3c483fb0d950e58f84d52807f6ed"
|
||||
dependencies = [
|
||||
"glam",
|
||||
"strafesnet_common",
|
||||
|
||||
@@ -8,9 +8,10 @@ burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
|
||||
wgpu = "29.0.0"
|
||||
|
||||
strafesnet_common = { version = "0.9.0", registry = "strafesnet" }
|
||||
strafesnet_graphics = { version = "=0.0.11-depth2", registry = "strafesnet" }
|
||||
strafesnet_graphics = { version = "=0.0.11-depth", registry = "strafesnet" }
|
||||
strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
|
||||
strafesnet_roblox_bot_player = { version = "=0.6.2-depth", registry = "strafesnet" }
|
||||
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
|
||||
pollster = "0.4.0"
|
||||
png = "0.18.1"
|
||||
|
||||
39
src/main.rs
39
src/main.rs
@@ -1,7 +1,7 @@
|
||||
use burn::backend::Autodiff;
|
||||
use burn::nn::loss::{MseLoss, Reduction};
|
||||
use burn::nn::{Linear, LinearConfig, Relu};
|
||||
use burn::optim::{GradientsParams, Optimizer, AdamConfig};
|
||||
use burn::optim::{GradientsParams, Optimizer, SgdConfig};
|
||||
use burn::prelude::*;
|
||||
|
||||
type InferenceBackend = burn::backend::Cuda<f32>;
|
||||
@@ -11,11 +11,31 @@ const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
|
||||
use strafesnet_graphics::setup;
|
||||
use strafesnet_roblox_bot_file::v0;
|
||||
|
||||
pub fn output_image_native(image_data: &[u8], texture_dims: (usize, usize), path: String) {
|
||||
use std::io::Write;
|
||||
|
||||
let mut png_data = Vec::<u8>::with_capacity(image_data.len());
|
||||
let mut encoder =
|
||||
png::Encoder::new(&mut png_data, texture_dims.0 as u32, texture_dims.1 as u32);
|
||||
encoder.set_color(png::ColorType::Grayscale);
|
||||
let mut png_writer = encoder.write_header().unwrap();
|
||||
png_writer.write_image_data(image_data).unwrap();
|
||||
png_writer.finish().unwrap();
|
||||
|
||||
let mut file = std::fs::File::create(&path).unwrap();
|
||||
file.write_all(&png_data).unwrap();
|
||||
}
|
||||
|
||||
const SIZE_X: usize = 64;
|
||||
const SIZE_Y: usize = 36;
|
||||
const INPUT: usize = SIZE_X * SIZE_Y;
|
||||
const HIDDEN: [usize; 2] = [
|
||||
const HIDDEN: [usize; 7] = [
|
||||
INPUT >> 1,
|
||||
INPUT >> 2,
|
||||
INPUT >> 3,
|
||||
INPUT >> 4,
|
||||
INPUT >> 5,
|
||||
INPUT >> 6,
|
||||
INPUT >> 7,
|
||||
];
|
||||
// MoveForward
|
||||
@@ -162,7 +182,7 @@ fn training() {
|
||||
let mut last_mx = first.event.mouse_pos.x;
|
||||
let mut last_my = first.event.mouse_pos.y;
|
||||
|
||||
for input_event in it {
|
||||
for (i, input_event) in it.enumerate() {
|
||||
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
|
||||
let mouse_dy = input_event.event.mouse_pos.y - last_my;
|
||||
last_mx = input_event.event.mouse_pos.x;
|
||||
@@ -302,6 +322,17 @@ fn training() {
|
||||
let inputs_end = inputs.len();
|
||||
println!("inputs = {:?}", &inputs[inputs_start..inputs_end]);
|
||||
|
||||
// write a png
|
||||
output_image_native(
|
||||
&inputs[i * INPUT..(i + 1) * INPUT]
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|f| (f * 255.0) as u8)
|
||||
.collect::<Vec<u8>>(),
|
||||
(SIZE_X, SIZE_Y),
|
||||
format!("depth_images/{i}.png").into(),
|
||||
);
|
||||
|
||||
texture_data.clear();
|
||||
}
|
||||
|
||||
@@ -310,7 +341,7 @@ fn training() {
|
||||
let mut model: Net<TrainingBackend> = Net::init(&device);
|
||||
println!("Training model ({} parameters)", model.num_params());
|
||||
|
||||
let mut optim = AdamConfig::new().init();
|
||||
let mut optim = SgdConfig::new().init();
|
||||
|
||||
let inputs = Tensor::from_data(
|
||||
TensorData::new(
|
||||
|
||||
Reference in New Issue
Block a user