Compare commits

..

2 Commits

Author SHA1 Message Date
ba3b962e86 batched mcts 2026-05-24 15:43:14 -05:00
23135b4386 Fixes 2026-05-24 13:28:33 -05:00
4 changed files with 206 additions and 89 deletions

View File

@ -1,7 +1,8 @@
#![recursion_limit = "256"] #![recursion_limit = "256"]
use burn::backend::{Autodiff, Wgpu}; use burn::backend::Autodiff;
use burn::optim::AdamConfig; use burn::optim::AdamConfig;
use burn_ndarray::NdArray;
use engine::mcts::MctsConfig; use engine::mcts::MctsConfig;
use engine::training::train::{train, TrainingConfig}; use engine::training::train::{train, TrainingConfig};
// fn main() { // fn main() {
@ -14,12 +15,12 @@ use engine::training::train::{train, TrainingConfig};
// } // }
fn main() { fn main() {
type MyBackend = Wgpu<f32, i32>; // type MyBackend = Wgpu<f32, i32>;
// type MyBackend = Cuda<f32, i32>; // type MyBackend = Cuda<f32, i32>;
// type MyBackend = NdArray<f32, i32>; type MyBackend = NdArray<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
let device = burn::backend::wgpu::WgpuDevice::default(); // let device = burn::backend::wgpu::WgpuDevice::default();
// let device = burn::backend::ndarray::NdArrayDevice::default(); let device = burn::backend::ndarray::NdArrayDevice::default();
// let device = burn::backend::cuda::CudaDevice::default(); // let device = burn::backend::cuda::CudaDevice::default();
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25); let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);

View File

@ -5,9 +5,10 @@ use crate::net::model::ChessModel;
use burn::prelude::Backend; use burn::prelude::Backend;
use burn::Tensor; use burn::Tensor;
use chess::BoardStatus::{Checkmate, Stalemate}; use chess::BoardStatus::{Checkmate, Stalemate};
use chess::Color::White; use chess::Color::{Black, White};
use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook}; use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES}; use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
use rand::SeedableRng;
use std::collections::HashMap; use std::collections::HashMap;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -58,16 +59,13 @@ impl Clone for Node {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct MctsResults { pub struct MctsResults {
pub board_state: BoardState, pub board_state: BoardState,
pub move_dist: HashMap<ChessMove, f32>, // Compact encoded move distribution: (encoded_move_index, probability)
pub move_dist: Vec<(usize, f32)>,
pub value: f32, pub value: f32,
} }
impl MctsResults { impl MctsResults {
pub fn new( pub fn new(board_state: BoardState, move_dist: Vec<(usize, f32)>, value: f32) -> MctsResults {
board_state: BoardState,
move_dist: HashMap<ChessMove, f32>,
value: f32,
) -> MctsResults {
MctsResults { MctsResults {
board_state, board_state,
move_dist, move_dist,
@ -123,12 +121,25 @@ impl<B: Backend> Mcts<B> {
let root = 0; let root = 0;
nodes.push(Node::new(0.0, board_state.clone(), None)); nodes.push(Node::new(0.0, board_state.clone(), None));
// Expand root to create initial children and priors
self.expand(root, &mut nodes, model, device); self.expand(root, &mut nodes, model, device);
// 👇 APPLY DIRICHLET NOISE HERE // Apply Dirichlet noise to root children
self.add_dirichlet_noise(root, &mut nodes); self.add_dirichlet_noise(root, &mut nodes);
for _ in 0..self.config.num_simulations { // We'll batch leaf evaluations to reduce per-leaf model calls and device-host syncs.
let mut sims_done = 0usize;
let num_sims = self.config.num_simulations;
// Tunable batch size for NN evaluation. Small value is safe; larger values increase throughput on GPU.
let batch_max = 32usize;
while sims_done < num_sims {
// Collect a batch of leaf nodes (and their selection paths)
let mut leaf_nodes: Vec<usize> = Vec::new();
let mut leaf_paths: Vec<Vec<usize>> = Vec::new();
let mut leaf_states: Vec<Tensor<B, 4>> = Vec::new();
while leaf_nodes.len() < std::cmp::min(batch_max, num_sims - sims_done) {
let mut path = vec![root]; let mut path = vec![root];
let mut current = root; let mut current = root;
@ -137,18 +148,94 @@ impl<B: Backend> Mcts<B> {
path.push(current); path.push(current);
} }
let value: f32 = self.expand(current, &mut nodes, model, device); // Record leaf node and its path
leaf_nodes.push(current);
leaf_paths.push(path.clone());
let color = nodes[current].board_state.board.side_to_move(); // Prepare state tensor for this leaf
self.backpropagate(&mut nodes, &path, value, color); let state: Tensor<B, 4> = encode_board_state_perspective(&nodes[current].board_state, device)
.reshape([1, 18, 8, 8]);
leaf_states.push(state);
sims_done += 1;
} }
let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)> if leaf_nodes.is_empty() {
for idx in nodes[root].children.iter() { break;
move_dist.insert( }
nodes[*idx].last_move.expect("move didnt exist"),
nodes[*idx].visit_count as f32 / self.config.num_simulations as f32, // Batch evaluate the collected leaf states
); let batch = Tensor::cat(leaf_states, 0);
let (policy_batch, value_batch) = model.forward(batch);
// Move tensors to host once per batch
let policy_data = policy_batch.into_data().to_vec::<f32>().unwrap();
let value_data = value_batch.into_data().to_vec::<f32>().unwrap();
let num_moves = policy_data.len() / leaf_nodes.len();
// Process each evaluated leaf: expand and backpropagate
for (i, &node_idx) in leaf_nodes.iter().enumerate() {
let path = &leaf_paths[i];
// slice for this sample's logits
let start = i * num_moves;
let end = start + num_moves;
let logits = &policy_data[start..end];
// Convert logits to probabilities with a numerically-stable softmax on host
let mut max_logit = std::f32::NEG_INFINITY;
for &v in logits.iter() {
if v > max_logit {
max_logit = v;
}
}
let mut exps_sum = 0.0f32;
// We'll build a Vec<f32> of probabilities lazily when needed
let mut probs: Vec<f32> = Vec::new();
probs.resize(num_moves, 0.0);
for (j, &v) in logits.iter().enumerate() {
let e = (v - max_logit).exp();
probs[j] = e;
exps_sum += e;
}
if exps_sum > 0.0 {
for p in probs.iter_mut() {
*p /= exps_sum;
}
}
// Expand: add legal moves as children with prior from probs
let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&nodes[node_idx].board_state.board).collect();
for mv in legal_moves {
let stm = nodes[node_idx].board_state.board.side_to_move();
let idx = encode_move(mv, stm);
let prior = probs[idx];
let mut new_board = nodes[node_idx].board_state.clone();
new_board.apply_move(mv);
let child_idx = nodes.len();
nodes.push(Node::new(prior, new_board, Some(mv)));
nodes[node_idx].children.push(child_idx);
}
// Backpropagate the value for this leaf
let value = value_data[i];
let color = nodes[node_idx].board_state.board.side_to_move();
self.backpropagate(&mut nodes, path, value, color);
}
}
// Build compact move distribution: encoded move index -> probability (visits / num_simulations)
let mut move_dist: Vec<(usize, f32)> = Vec::with_capacity(nodes[root].children.len());
let stm = board_state.board.side_to_move();
let denom = self.config.num_simulations as f32;
for child_idx in nodes[root].children.iter() {
let mv = nodes[*child_idx].last_move.expect("move didnt exist");
let enc = encode_move(mv, stm);
let prob = nodes[*child_idx].visit_count as f32 / denom;
move_dist.push((enc, prob));
} }
MctsResults::new(board_state.clone(), move_dist, nodes[root].value()) MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
@ -161,6 +248,24 @@ impl<B: Backend> Mcts<B> {
model: &ChessModel<B>, model: &ChessModel<B>,
device: &B::Device, device: &B::Device,
) -> f32 { ) -> f32 {
if arena[node_idx].board_state.status == BoardStateStatus::Stalemate
|| arena[node_idx].board_state.status == BoardStateStatus::Threefold
|| arena[node_idx].board_state.status == BoardStateStatus::FiftyMove
{
0.0
} else if arena[node_idx].board_state.status == BoardStateStatus::WhiteWinner {
if arena[node_idx].board_state.board.side_to_move() == Black {
1.0
} else {
-1.0
}
} else if arena[node_idx].board_state.status == BoardStateStatus::BlackWinner {
if arena[node_idx].board_state.board.side_to_move() == White {
1.0
} else {
-1.0
}
} else {
let state: Tensor<B, 4> = let state: Tensor<B, 4> =
encode_board_state_perspective(&arena[node_idx].board_state, device) encode_board_state_perspective(&arena[node_idx].board_state, device)
.reshape([1, 18, 8, 8]); .reshape([1, 18, 8, 8]);
@ -189,6 +294,7 @@ impl<B: Backend> Mcts<B> {
value_head.into_data().to_vec().unwrap()[0] value_head.into_data().to_vec().unwrap()[0]
} }
}
fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) { fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) {
for &idx in path { for &idx in path {
@ -229,9 +335,14 @@ fn dirichlet_sample(size: usize, alpha: f32) -> Vec<f32> {
let gamma = Gamma::new(alpha as f64, 1.0).unwrap(); let gamma = Gamma::new(alpha as f64, 1.0).unwrap();
let mut samples: Vec<f32> = (0..size) // Use a single SmallRng seeded from system time (avoid depending on thread_rng helper)
.map(|_| gamma.sample(&mut rand::rng()) as f32) let now = std::time::SystemTime::now()
.collect(); .duration_since(std::time::UNIX_EPOCH)
.unwrap();
let seed = now.as_nanos() as u64;
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
let mut samples: Vec<f32> = (0..size).map(|_| gamma.sample(&mut rng) as f32).collect();
let sum: f32 = samples.iter().sum(); let sum: f32 = samples.iter().sum();
@ -334,8 +445,6 @@ pub fn heuristic_eval(board: &Board, perspective: Color) -> f32 {
} }
value value
// board.checkers()
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@ -1,5 +1,5 @@
use crate::mcts::{BoardState, MctsResults}; use crate::mcts::{BoardState, MctsResults};
use crate::net::encoding::{encode_board_state_perspective, encode_move}; use crate::net::encoding::encode_board_state_perspective;
use burn::data::dataloader::batcher::Batcher; use burn::data::dataloader::batcher::Batcher;
use burn::nn::conv::Conv2dConfig; use burn::nn::conv::Conv2dConfig;
use burn::nn::loss::{MseLoss, Reduction}; use burn::nn::loss::{MseLoss, Reduction};
@ -13,8 +13,6 @@ use burn::{
prelude::*, prelude::*,
}; };
use burn_ndarray::NdArray; use burn_ndarray::NdArray;
use chess::ChessMove;
use std::collections::HashMap;
/* /*
Input planes: Input planes:
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King) 1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
@ -255,14 +253,15 @@ impl<B: Backend> InferenceStep for ChessModel<B> {
#[derive(Clone)] #[derive(Clone)]
pub struct TrainingSample { pub struct TrainingSample {
pub board_state: BoardState, pub board_state: BoardState,
pub policy_target: HashMap<ChessMove, f32>, // Compact representation: list of (encoded_move_index, probability)
pub policy_target: Vec<(usize, f32)>,
pub value_target: f32, pub value_target: f32,
} }
impl TrainingSample { impl TrainingSample {
pub fn new( pub fn new(
board_state: BoardState, board_state: BoardState,
policy_target: HashMap<ChessMove, f32>, policy_target: Vec<(usize, f32)>,
value_target: f32, value_target: f32,
) -> Self { ) -> Self {
TrainingSample { TrainingSample {
@ -273,6 +272,7 @@ impl TrainingSample {
} }
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self { pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
// move_dist is already a compact Vec<(encoded_move_index, prob)>
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome) TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
} }
} }
@ -301,9 +301,8 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
.cloned() .cloned()
.map(|item| { .map(|item| {
let mut policy = vec![0.0f32; 4672]; let mut policy = vec![0.0f32; 4672];
let stm = item.board_state.board.side_to_move(); for (idx, prob) in item.policy_target.iter() {
for (mv, prob) in item.policy_target { policy[*idx] = *prob;
policy[encode_move(mv, stm)] = prob;
} }
// Normalize // Normalize

View File

@ -1,19 +1,21 @@
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults}; use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
use crate::net::encoding::decode_move;
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample}; use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
use burn::data::dataloader::batcher::Batcher; use burn::data::dataloader::batcher::Batcher;
use burn::module::{AutodiffModule, Module}; use burn::module::{AutodiffModule, Module};
use burn::optim::{AdamConfig, GradientsParams, Optimizer}; use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder}; use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
use burn::tensor::backend::AutodiffBackend; use burn::tensor::backend::AutodiffBackend;
use chess::ChessMove; use chess::{ChessMove, Color};
use rand::rngs::ThreadRng; use rand::rngs::SmallRng;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::RngExt; use rand::{RngExt, SeedableRng};
use std::collections::{HashMap, VecDeque}; use std::collections::VecDeque;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use std::time::{SystemTime, UNIX_EPOCH};
pub struct TrainingConfig { pub struct TrainingConfig {
pub max_time_s: Option<u64>, pub max_time_s: Option<u64>,
@ -35,8 +37,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
let model_path = format!("artifacts/{}", training_config.model_name.as_str()); let model_path = format!("artifacts/{}", training_config.model_name.as_str());
println!("Creating model..."); println!("Creating model...");
let mut model: ChessModel<B> = ChessModelConfig::init( let mut model: ChessModel<B> = ChessModelConfig::init(
training_config.hidden_channels,
training_config.num_blocks, training_config.num_blocks,
training_config.hidden_channels,
&device, &device,
); );
if training_config.load_model { if training_config.load_model {
@ -67,7 +69,14 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
_marker: PhantomData, _marker: PhantomData,
}; };
let mut rng = rand::rng(); // Create RNG once and reuse it for sampling and shuffling
// Seed from system time (platform default entropy may be unavailable in some contexts)
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let seed = now.as_nanos() as u64;
let mut rng = SmallRng::seed_from_u64(seed);
// Initialize optimizer once so state (moments) persist across steps
let mut optim = training_config.optimizer.init();
println!("Starting training..."); println!("Starting training...");
while train.load(Ordering::Relaxed) { while train.load(Ordering::Relaxed) {
@ -91,7 +100,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp); let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
let mv = sample_move(&adjusted, &mut rng).unwrap(); let stm = board_state.board.side_to_move();
let mv = sample_move(&adjusted, &mut rng, stm).unwrap();
println!("playing move: {}", mv); println!("playing move: {}", mv);
board_state.apply_move(mv) board_state.apply_move(mv)
} }
@ -140,8 +150,6 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
let batch = batcher.batch(samples, &device); let batch = batcher.batch(samples, &device);
let mut optim = training_config.optimizer.init();
let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets); let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
let grads = output.loss.backward(); let grads = output.loss.backward();
@ -181,56 +189,56 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
return; return;
} }
/// Re-weights a move distribution with a sampling temperature.
///
/// `visits` is a list of `(encoded_move_index, probability)` pairs. Each
/// probability is raised to `1 / temperature` and the result is renormalised
/// so the probabilities sum to 1.
///
/// * Empty input yields an empty vector.
/// * `temperature == 0.0` is the deterministic case: the entry with the
///   largest probability is returned alone with probability `1.0`. NaN
///   entries are ignored when looking for the maximum; if every entry is NaN
///   the first entry is used.
/// * If the re-weighted probabilities sum to zero or less, they are returned
///   without normalisation.
fn apply_temperature(visits: &[(usize, f32)], temperature: f32) -> Vec<(usize, f32)> {
    let Some(&(first_idx, _)) = visits.first() else {
        return Vec::new();
    };

    // Special case: deterministic selection.
    if temperature == 0.0 {
        // `total_cmp` never panics, unlike `partial_cmp(..).unwrap()`, and the
        // NaN filter keeps a corrupted entry from winning the comparison.
        let best_idx = visits
            .iter()
            .filter(|(_, p)| !p.is_nan())
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map_or(first_idx, |&(idx, _)| idx);
        return vec![(best_idx, 1.0)];
    }

    let inv_temp = 1.0 / temperature;

    // Step 1: apply exponent.
    let mut adjusted: Vec<(usize, f32)> = visits
        .iter()
        .map(|&(idx, v)| (idx, v.powf(inv_temp)))
        .collect();

    // Step 2: normalize.
    let sum: f32 = adjusted.iter().map(|&(_, v)| v).sum();
    if sum <= 0.0 {
        return adjusted; // fallback (shouldn't happen in normal MCTS)
    }
    for (_, v) in adjusted.iter_mut() {
        *v /= sum;
    }

    adjusted
}
fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut ThreadRng) -> Option<ChessMove> { fn sample_move(
dist: &[(usize, f32)],
rng: &mut SmallRng,
side_to_move: Color,
) -> Option<ChessMove> {
let mut r: f32 = rng.random_range(0.0..1.0); let mut r: f32 = rng.random_range(0.0..1.0);
for (m, p) in dist { for (idx, p) in dist {
r -= p; r -= *p;
if r <= 0.0 { if r <= 0.0 {
return Some(m.clone()); return Some(decode_move(*idx, side_to_move));
} }
} }
// fallback due to floating point drift // fallback due to floating point drift
dist.keys().next().cloned() dist.get(0).map(|(idx, _)| decode_move(*idx, side_to_move))
} }