From ba3b962e8615a20970bc2cd6d446a7672075517f Mon Sep 17 00:00:00 2001
From: DragonDuck24
Date: Sun, 24 May 2026 15:43:14 -0500
Subject: [PATCH] batched mcts

---
 engine/src/mcts.rs           | 24 +++++++++----------
 engine/src/net/model.rs      | 15 ++++++------
 engine/src/training/train.rs | 46 +++++++++++++++++++-----------------
 3 files changed, 43 insertions(+), 42 deletions(-)

diff --git a/engine/src/mcts.rs b/engine/src/mcts.rs
index 7e2d7a3..7833ed6 100644
--- a/engine/src/mcts.rs
+++ b/engine/src/mcts.rs
@@ -59,16 +59,13 @@ impl Clone for Node {
 #[derive(Clone, Debug)]
 pub struct MctsResults {
     pub board_state: BoardState,
-    pub move_dist: HashMap<ChessMove, f32>,
+    // Compact encoded move distribution: (encoded_move_index, probability)
+    pub move_dist: Vec<(usize, f32)>,
     pub value: f32,
 }
 
 impl MctsResults {
-    pub fn new(
-        board_state: BoardState,
-        move_dist: HashMap<ChessMove, f32>,
-        value: f32,
-    ) -> MctsResults {
+    pub fn new(board_state: BoardState, move_dist: Vec<(usize, f32)>, value: f32) -> MctsResults {
         MctsResults {
             board_state,
             move_dist,
@@ -230,12 +227,15 @@ impl Mcts {
             }
         }
 
-        let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)>
-        for idx in nodes[root].children.iter() {
-            move_dist.insert(
-                nodes[*idx].last_move.expect("move didnt exist"),
-                nodes[*idx].visit_count as f32 / self.config.num_simulations as f32,
-            );
+        // Build compact move distribution: encoded move index -> probability (visits / num_simulations)
+        let mut move_dist: Vec<(usize, f32)> = Vec::with_capacity(nodes[root].children.len());
+        let stm = board_state.board.side_to_move();
+        let denom = self.config.num_simulations as f32;
+        for child_idx in nodes[root].children.iter() {
+            let mv = nodes[*child_idx].last_move.expect("move didnt exist");
+            let enc = encode_move(mv, stm);
+            let prob = nodes[*child_idx].visit_count as f32 / denom;
+            move_dist.push((enc, prob));
         }
 
         MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
diff --git a/engine/src/net/model.rs b/engine/src/net/model.rs
index 369dfee..0db3344 100644
--- a/engine/src/net/model.rs
+++ b/engine/src/net/model.rs
@@ -1,5 +1,5 @@
 use crate::mcts::{BoardState, MctsResults};
-use crate::net::encoding::{encode_board_state_perspective, encode_move};
+use crate::net::encoding::encode_board_state_perspective;
 use burn::data::dataloader::batcher::Batcher;
 use burn::nn::conv::Conv2dConfig;
 use burn::nn::loss::{MseLoss, Reduction};
@@ -13,8 +13,6 @@ use burn::{
     prelude::*,
 };
 use burn_ndarray::NdArray;
-use chess::ChessMove;
-use std::collections::HashMap;
 /*
 Input planes:
 1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
@@ -255,14 +253,15 @@ impl InferenceStep for ChessModel {
 #[derive(Clone)]
 pub struct TrainingSample {
     pub board_state: BoardState,
-    pub policy_target: HashMap<ChessMove, f32>,
+    // Compact representation: list of (encoded_move_index, probability)
+    pub policy_target: Vec<(usize, f32)>,
     pub value_target: f32,
 }
 
 impl TrainingSample {
     pub fn new(
         board_state: BoardState,
-        policy_target: HashMap<ChessMove, f32>,
+        policy_target: Vec<(usize, f32)>,
         value_target: f32,
     ) -> Self {
         TrainingSample {
@@ -273,6 +272,7 @@ impl TrainingSample {
     }
 
     pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
+        // move_dist is already a compact Vec<(encoded_move_index, prob)>
         TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
     }
 }
@@ -301,9 +301,8 @@ impl Batcher> for ChessBatcher {
             .cloned()
             .map(|item| {
                 let mut policy = vec![0.0f32; 4672];
-                let stm = item.board_state.board.side_to_move();
-                for (mv, prob) in item.policy_target {
-                    policy[encode_move(mv, stm)] = prob;
+                for (idx, prob) in item.policy_target.iter() {
+                    policy[*idx] = *prob;
                 }
 
                 // Normalize
diff --git a/engine/src/training/train.rs b/engine/src/training/train.rs
index 2a732ae..d6b7bd1 100644
--- a/engine/src/training/train.rs
+++ b/engine/src/training/train.rs
@@ -1,15 +1,16 @@
 use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
+use crate::net::encoding::decode_move;
 use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
 use burn::data::dataloader::batcher::Batcher;
 use burn::module::{AutodiffModule, Module};
 use burn::optim::{AdamConfig, GradientsParams, Optimizer};
 use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
 use burn::tensor::backend::AutodiffBackend;
-use chess::ChessMove;
+use chess::{ChessMove, Color};
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
 use rand::{RngExt, SeedableRng};
-use std::collections::{HashMap, VecDeque};
+use std::collections::VecDeque;
 use std::marker::PhantomData;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -99,7 +100,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
 
             let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
 
-            let mv = sample_move(&adjusted, &mut rng).unwrap();
+            let stm = board_state.board.side_to_move();
+            let mv = sample_move(&adjusted, &mut rng, stm).unwrap();
             println!("playing move: {}", mv);
             board_state.apply_move(mv)
         }
@@ -187,56 +189,56 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
     return;
 }
 
-fn apply_temperature(
-    visits: &HashMap<ChessMove, f32>,
-    temperature: f32,
-) -> HashMap<ChessMove, f32> {
+fn apply_temperature(visits: &[(usize, f32)], temperature: f32) -> Vec<(usize, f32)> {
     if visits.is_empty() {
-        return HashMap::new();
+        return Vec::new();
     }
 
     // Special case: deterministic selection
     if temperature == 0.0 {
-        let (&best_move, _) = visits
+        let (&best_idx, _) = visits
             .iter()
-            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
+            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
+            .map(|(i, p)| (i, p))
             .unwrap();
 
-        let mut out = HashMap::new();
-        out.insert(best_move.clone(), 1.0);
-        return out;
+        return vec![(best_idx, 1.0)];
     }
 
     let inv_temp = 1.0 / temperature;
 
     // Step 1: apply exponent
-    let mut adjusted: HashMap<ChessMove, f32> =
-        visits.iter().map(|(m, v)| (*m, v.powf(inv_temp))).collect();
+    let mut adjusted: Vec<(usize, f32)> =
+        visits.iter().map(|(i, v)| (*i, v.powf(inv_temp))).collect();
 
     // Step 2: normalize
-    let sum: f32 = adjusted.values().sum();
+    let sum: f32 = adjusted.iter().map(|(_, v)| *v).sum();
 
     if sum <= 0.0 {
         return adjusted; // fallback (shouldn't happen in normal MCTS)
     }
 
-    for v in adjusted.values_mut() {
+    for (_, v) in adjusted.iter_mut() {
         *v /= sum;
     }
 
     adjusted
 }
 
-fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut SmallRng) -> Option<ChessMove> {
+fn sample_move(
+    dist: &[(usize, f32)],
+    rng: &mut SmallRng,
+    side_to_move: Color,
+) -> Option<ChessMove> {
     let mut r: f32 = rng.random_range(0.0..1.0);
 
-    for (m, p) in dist {
-        r -= p;
+    for (idx, p) in dist {
+        r -= *p;
         if r <= 0.0 {
-            return Some(m.clone());
+            return Some(decode_move(*idx, side_to_move));
         }
     }
 
     // fallback due to floating point drift
-    dist.keys().next().cloned()
+    dist.get(0).map(|(idx, _)| decode_move(*idx, side_to_move))
 }