From 7d5c0cf7ebbbba7f4730d49939a194dd21baac5e Mon Sep 17 00:00:00 2001 From: DragonDuck24 Date: Sat, 23 May 2026 23:12:27 -0500 Subject: [PATCH] updates --- engine/src/main.rs | 16 ++++++++-------- engine/src/mcts.rs | 5 ++--- engine/src/net/model.rs | 9 ++++++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/engine/src/main.rs b/engine/src/main.rs index dbaead3..843cfba 100644 --- a/engine/src/main.rs +++ b/engine/src/main.rs @@ -1,6 +1,6 @@ #![recursion_limit = "256"] -use burn::backend::{Autodiff, Cuda}; +use burn::backend::{Autodiff, Wgpu}; use burn::optim::AdamConfig; use engine::mcts::MctsConfig; use engine::training::train::{train, TrainingConfig}; @@ -14,15 +14,15 @@ use engine::training::train::{train, TrainingConfig}; // } fn main() { - // type MyBackend = Wgpu; - type MyBackend = Cuda; + type MyBackend = Wgpu; + // type MyBackend = Cuda; // type MyBackend = NdArray; type MyAutodiffBackend = Autodiff; - // let device = burn::backend::wgpu::WgpuDevice::default(); + let device = burn::backend::wgpu::WgpuDevice::default(); // let device = burn::backend::ndarray::NdArrayDevice::default(); - let device = burn::backend::cuda::CudaDevice::default(); + // let device = burn::backend::cuda::CudaDevice::default(); - let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25); + let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25); let adam_config = AdamConfig::new(); @@ -34,8 +34,8 @@ fn main() { load_model: false, hidden_channels: 64, num_blocks: 4, - batch_size: 128, // num positions sampled to a batch (2048) - num_episodes: 10, // num games generated per iteration (5000) + batch_size: 1024, // num positions sampled to a batch (2048) + num_episodes: 100, // num games generated per iteration (5000) buffer_max_size: 200_000, // max number of samples in the buffer mcts_config, optimizer: adam_config, diff --git a/engine/src/mcts.rs b/engine/src/mcts.rs index 133fe04..6f12328 100644 --- a/engine/src/mcts.rs +++ b/engine/src/mcts.rs @@ -10,7 +10,6 @@ use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook}; use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES}; use std::collections::HashMap; use std::marker::PhantomData; -use std::time::Instant; pub struct Node { pub prior: f32, @@ -168,9 +167,9 @@ impl Mcts { let state: Tensor = encode_board_state_perspective(&arena[node_idx].board_state, device) .reshape([1, 18, 8, 8]); - let start = Instant::now(); + // let start = Instant::now(); let (policy_head, value_head) = model.forward(state); - println!("time: {:?}", start.elapsed()); + // println!("time: {:?}", start.elapsed()); let legal_moves: Vec = MoveGen::new_legal(&arena[node_idx].board_state.board).collect(); diff --git a/engine/src/net/model.rs b/engine/src/net/model.rs index 3b87bcc..369dfee 100644 --- a/engine/src/net/model.rs +++ b/engine/src/net/model.rs @@ -313,15 +313,18 @@ impl Batcher> for ChessBatcher { *p /= sum; } } - Tensor::::from_floats(TensorData::new(policy, [4672]), device).unsqueeze() + Tensor::::from_floats(TensorData::new(policy, [1, 4672]), device).unsqueeze() }) .collect::>(); let value_target_tensors = items .iter() .map(|item| { - Tensor::::from_floats(TensorData::new(vec![item.value_target], [1]), device) - .reshape([1, 1]) + Tensor::::from_floats( + TensorData::new(vec![item.value_target], [1, 1]), + device, + ) + .reshape([1, 1]) }) .collect::>();