diff --git a/engine/src/main.rs b/engine/src/main.rs index 843cfba..6cd3487 100644 --- a/engine/src/main.rs +++ b/engine/src/main.rs @@ -1,7 +1,8 @@ #![recursion_limit = "256"] -use burn::backend::{Autodiff, Wgpu}; +use burn::backend::Autodiff; use burn::optim::AdamConfig; +use burn_ndarray::NdArray; use engine::mcts::MctsConfig; use engine::training::train::{train, TrainingConfig}; // fn main() { @@ -14,12 +15,12 @@ use engine::training::train::{train, TrainingConfig}; // } fn main() { - type MyBackend = Wgpu; + // type MyBackend = Wgpu; // type MyBackend = Cuda; - // type MyBackend = NdArray; + type MyBackend = NdArray; type MyAutodiffBackend = Autodiff; - let device = burn::backend::wgpu::WgpuDevice::default(); - // let device = burn::backend::ndarray::NdArrayDevice::default(); + // let device = burn::backend::wgpu::WgpuDevice::default(); + let device = burn::backend::ndarray::NdArrayDevice::default(); // let device = burn::backend::cuda::CudaDevice::default(); let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25); diff --git a/engine/src/mcts.rs b/engine/src/mcts.rs index fc3a507..7e2d7a3 100644 --- a/engine/src/mcts.rs +++ b/engine/src/mcts.rs @@ -5,9 +5,10 @@ use crate::net::model::ChessModel; use burn::prelude::Backend; use burn::Tensor; use chess::BoardStatus::{Checkmate, Stalemate}; -use chess::Color::White; +use chess::Color::{Black, White}; use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook}; use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES}; +use rand::SeedableRng; use std::collections::HashMap; use std::marker::PhantomData; @@ -123,24 +124,110 @@ impl Mcts { let root = 0; nodes.push(Node::new(0.0, board_state.clone(), None)); + // Expand root to create initial children and priors self.expand(root, &mut nodes, model, device); - // 👇 APPLY DIRICHLET NOISE HERE + // Apply Dirichlet noise to root children self.add_dirichlet_noise(root, &mut nodes); - for _ in 0..self.config.num_simulations { - let mut path = vec![root]; - let mut current = root; + // We'll batch leaf evaluations to reduce per-leaf model calls and device-host syncs. + let mut sims_done = 0usize; + let num_sims = self.config.num_simulations; + // Tunable batch size for NN evaluation. Small value is safe; larger values increase throughput on GPU. + let batch_max = 32usize; - while !nodes[current].children.is_empty() { - current = nodes[current].select_child(&nodes, &self.config.c_puct); - path.push(current); + while sims_done < num_sims { + // Collect a batch of leaf nodes (and their selection paths) + let mut leaf_nodes: Vec = Vec::new(); + let mut leaf_paths: Vec> = Vec::new(); + let mut leaf_states: Vec> = Vec::new(); + + while leaf_nodes.len() < std::cmp::min(batch_max, num_sims - sims_done) { + let mut path = vec![root]; + let mut current = root; + + while !nodes[current].children.is_empty() { + current = nodes[current].select_child(&nodes, &self.config.c_puct); + path.push(current); + } + + // Record leaf node and its path + leaf_nodes.push(current); + leaf_paths.push(path.clone()); + + // Prepare state tensor for this leaf + let state: Tensor = encode_board_state_perspective(&nodes[current].board_state, device) + .reshape([1, 18, 8, 8]); + leaf_states.push(state); + + sims_done += 1; } - let value: f32 = self.expand(current, &mut nodes, model, device); + if leaf_nodes.is_empty() { + break; + } - let color = nodes[current].board_state.board.side_to_move(); - self.backpropagate(&mut nodes, &path, value, color); + // Batch evaluate the collected leaf states + let batch = Tensor::cat(leaf_states, 0); + let (policy_batch, value_batch) = model.forward(batch); + + // Move tensors to host once per batch + let policy_data = policy_batch.into_data().to_vec::().unwrap(); + let value_data = value_batch.into_data().to_vec::().unwrap(); + + let num_moves = policy_data.len() / leaf_nodes.len(); + + // Process each evaluated leaf: expand and backpropagate + for (i, &node_idx) in leaf_nodes.iter().enumerate() { + let path = &leaf_paths[i]; + + // slice for this sample's logits + let start = i * num_moves; + let end = start + num_moves; + let logits = &policy_data[start..end]; + + // Convert logits to probabilities with a numerically-stable softmax on host + let mut max_logit = std::f32::NEG_INFINITY; + for &v in logits.iter() { + if v > max_logit { + max_logit = v; + } + } + let mut exps_sum = 0.0f32; + // We'll build a Vec of probabilities lazily when needed + let mut probs: Vec = Vec::new(); + probs.resize(num_moves, 0.0); + for (j, &v) in logits.iter().enumerate() { + let e = (v - max_logit).exp(); + probs[j] = e; + exps_sum += e; + } + if exps_sum > 0.0 { + for p in probs.iter_mut() { + *p /= exps_sum; + } + } + + // Expand: add legal moves as children with prior from probs + let legal_moves: Vec = MoveGen::new_legal(&nodes[node_idx].board_state.board).collect(); + for mv in legal_moves { + let stm = nodes[node_idx].board_state.board.side_to_move(); + let idx = encode_move(mv, stm); + let prior = probs[idx]; + + let mut new_board = nodes[node_idx].board_state.clone(); + new_board.apply_move(mv); + + let child_idx = nodes.len(); + nodes.push(Node::new(prior, new_board, Some(mv))); + nodes[node_idx].children.push(child_idx); + } + + // Backpropagate the value for this leaf + let value = value_data[i]; + let color = nodes[node_idx].board_state.board.side_to_move(); + self.backpropagate(&mut nodes, path, value, color); + } } let mut move_dist: HashMap = HashMap::new(); // TODO: make vec<(Chessmove, f32)> @@ -161,33 +248,52 @@ impl Mcts { model: &ChessModel, device: &B::Device, ) -> f32 { - let state: Tensor = - encode_board_state_perspective(&arena[node_idx].board_state, device) - .reshape([1, 18, 8, 8]); - // let start = Instant::now(); - let (policy_head, value_head) = model.forward(state); - // println!("time: {:?}", start.elapsed()); + if arena[node_idx].board_state.status == BoardStateStatus::Stalemate + || arena[node_idx].board_state.status == BoardStateStatus::Threefold + || arena[node_idx].board_state.status == BoardStateStatus::FiftyMove + { + 0.0 + } else if arena[node_idx].board_state.status == BoardStateStatus::WhiteWinner { + if arena[node_idx].board_state.board.side_to_move() == Black { + 1.0 + } else { + -1.0 + } + } else if arena[node_idx].board_state.status == BoardStateStatus::BlackWinner { + if arena[node_idx].board_state.board.side_to_move() == White { + 1.0 + } else { + -1.0 + } + } else { + let state: Tensor = + encode_board_state_perspective(&arena[node_idx].board_state, device) + .reshape([1, 18, 8, 8]); + // let start = Instant::now(); + let (policy_head, value_head) = model.forward(state); + // println!("time: {:?}", start.elapsed()); - let legal_moves: Vec = - MoveGen::new_legal(&arena[node_idx].board_state.board).collect(); + let legal_moves: Vec = + MoveGen::new_legal(&arena[node_idx].board_state.board).collect(); - let policy = policy_head.into_data().to_vec::().unwrap(); + let policy = policy_head.into_data().to_vec::().unwrap(); - for mv in legal_moves { - let stm = arena[node_idx].board_state.board.side_to_move(); - let idx = encode_move(mv, stm); - let prior = policy[idx]; + for mv in legal_moves { + let stm = arena[node_idx].board_state.board.side_to_move(); + let idx = encode_move(mv, stm); + let prior = policy[idx]; - let mut new_board = arena[node_idx].board_state.clone(); - new_board.apply_move(mv); + let mut new_board = arena[node_idx].board_state.clone(); + new_board.apply_move(mv); - let child_idx = arena.len(); + let child_idx = arena.len(); - arena.push(Node::new(prior, new_board, Some(mv))); - arena[node_idx].children.push(child_idx); + arena.push(Node::new(prior, new_board, Some(mv))); + arena[node_idx].children.push(child_idx); + } + + value_head.into_data().to_vec().unwrap()[0] } - - value_head.into_data().to_vec().unwrap()[0] } fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) { @@ -229,9 +335,14 @@ fn dirichlet_sample(size: usize, alpha: f32) -> Vec { let gamma = Gamma::new(alpha as f64, 1.0).unwrap(); - let mut samples: Vec = (0..size) - .map(|_| gamma.sample(&mut rand::rng()) as f32) - .collect(); + // Use a single SmallRng seeded from system time (avoid depending on thread_rng helper) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + let seed = now.as_nanos() as u64; + let mut rng = rand::rngs::SmallRng::seed_from_u64(seed); + + let mut samples: Vec = (0..size).map(|_| gamma.sample(&mut rng) as f32).collect(); let sum: f32 = samples.iter().sum(); @@ -334,8 +445,6 @@ pub fn heuristic_eval(board: &Board, perspective: Color) -> f32 { } value - - // board.checkers() } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/engine/src/training/train.rs b/engine/src/training/train.rs index e2203e5..2a732ae 100644 --- a/engine/src/training/train.rs +++ b/engine/src/training/train.rs @@ -6,14 +6,15 @@ use burn::optim::{AdamConfig, GradientsParams, Optimizer}; use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder}; use burn::tensor::backend::AutodiffBackend; use chess::ChessMove; -use rand::rngs::ThreadRng; +use rand::rngs::SmallRng; use rand::seq::SliceRandom; -use rand::RngExt; +use rand::{RngExt, SeedableRng}; use std::collections::{HashMap, VecDeque}; use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Instant; +use std::time::{SystemTime, UNIX_EPOCH}; pub struct TrainingConfig { pub max_time_s: Option, @@ -35,8 +36,8 @@ pub fn train(training_config: TrainingConfig, device: B::Dev let model_path = format!("artifacts/{}", training_config.model_name.as_str()); println!("Creating model..."); let mut model: ChessModel = ChessModelConfig::init( - training_config.hidden_channels, training_config.num_blocks, + training_config.hidden_channels, &device, ); if training_config.load_model { @@ -67,7 +68,14 @@ pub fn train(training_config: TrainingConfig, device: B::Dev _marker: PhantomData, }; - let mut rng = rand::rng(); + // Create RNG once and reuse it for sampling and shuffling + // Seed from system time (platform default entropy may be unavailable in some contexts) + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + let seed = now.as_nanos() as u64; + let mut rng = SmallRng::seed_from_u64(seed); + + // Initialize optimizer once so state (moments) persist across steps + let mut optim = training_config.optimizer.init(); println!("Starting training..."); while train.load(Ordering::Relaxed) { @@ -140,8 +148,6 @@ pub fn train(training_config: TrainingConfig, device: B::Dev let batch = batcher.batch(samples, &device); - let mut optim = training_config.optimizer.init(); - let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets); let grads = output.loss.backward(); @@ -221,7 +227,7 @@ fn apply_temperature( adjusted } -fn sample_move(dist: &HashMap, rng: &mut ThreadRng) -> Option { +fn sample_move(dist: &HashMap, rng: &mut SmallRng) -> Option { let mut r: f32 = rng.random_range(0.0..1.0); for (m, p) in dist {