allow configure gpu

This commit is contained in:
2026-03-26 11:44:24 -07:00
parent 9e441c1d95
commit 989bc37dc4

View File

@@ -70,6 +70,11 @@ impl<B: Backend> Net<B> {
}
fn training() {
let gpu_id: usize = std::env::args()
.skip(1)
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
// load map
// load replay
// setup player
@@ -292,7 +297,7 @@ fn training() {
texture_data.clear();
}
let device = Default::default();
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
println!("Training model ({} parameters)...", model.num_params());