batched mcts
This commit is contained in:
parent
23135b4386
commit
ba3b962e86
@ -59,16 +59,13 @@ impl Clone for Node {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct MctsResults {
|
pub struct MctsResults {
|
||||||
pub board_state: BoardState,
|
pub board_state: BoardState,
|
||||||
pub move_dist: HashMap<ChessMove, f32>,
|
// Compact encoded move distribution: (encoded_move_index, probability)
|
||||||
|
pub move_dist: Vec<(usize, f32)>,
|
||||||
pub value: f32,
|
pub value: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MctsResults {
|
impl MctsResults {
|
||||||
pub fn new(
|
pub fn new(board_state: BoardState, move_dist: Vec<(usize, f32)>, value: f32) -> MctsResults {
|
||||||
board_state: BoardState,
|
|
||||||
move_dist: HashMap<ChessMove, f32>,
|
|
||||||
value: f32,
|
|
||||||
) -> MctsResults {
|
|
||||||
MctsResults {
|
MctsResults {
|
||||||
board_state,
|
board_state,
|
||||||
move_dist,
|
move_dist,
|
||||||
@ -230,12 +227,15 @@ impl<B: Backend> Mcts<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)>
|
// Build compact move distribution: encoded move index -> probability (visits / num_simulations)
|
||||||
for idx in nodes[root].children.iter() {
|
let mut move_dist: Vec<(usize, f32)> = Vec::with_capacity(nodes[root].children.len());
|
||||||
move_dist.insert(
|
let stm = board_state.board.side_to_move();
|
||||||
nodes[*idx].last_move.expect("move didnt exist"),
|
let denom = self.config.num_simulations as f32;
|
||||||
nodes[*idx].visit_count as f32 / self.config.num_simulations as f32,
|
for child_idx in nodes[root].children.iter() {
|
||||||
);
|
let mv = nodes[*child_idx].last_move.expect("move didnt exist");
|
||||||
|
let enc = encode_move(mv, stm);
|
||||||
|
let prob = nodes[*child_idx].visit_count as f32 / denom;
|
||||||
|
move_dist.push((enc, prob));
|
||||||
}
|
}
|
||||||
|
|
||||||
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
|
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use crate::mcts::{BoardState, MctsResults};
|
use crate::mcts::{BoardState, MctsResults};
|
||||||
use crate::net::encoding::{encode_board_state_perspective, encode_move};
|
use crate::net::encoding::encode_board_state_perspective;
|
||||||
use burn::data::dataloader::batcher::Batcher;
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
use burn::nn::conv::Conv2dConfig;
|
use burn::nn::conv::Conv2dConfig;
|
||||||
use burn::nn::loss::{MseLoss, Reduction};
|
use burn::nn::loss::{MseLoss, Reduction};
|
||||||
@ -13,8 +13,6 @@ use burn::{
|
|||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use burn_ndarray::NdArray;
|
use burn_ndarray::NdArray;
|
||||||
use chess::ChessMove;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
/*
|
/*
|
||||||
Input planes:
|
Input planes:
|
||||||
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
||||||
@ -255,14 +253,15 @@ impl<B: Backend> InferenceStep for ChessModel<B> {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct TrainingSample {
|
pub struct TrainingSample {
|
||||||
pub board_state: BoardState,
|
pub board_state: BoardState,
|
||||||
pub policy_target: HashMap<ChessMove, f32>,
|
// Compact representation: list of (encoded_move_index, probability)
|
||||||
|
pub policy_target: Vec<(usize, f32)>,
|
||||||
pub value_target: f32,
|
pub value_target: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrainingSample {
|
impl TrainingSample {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
policy_target: HashMap<ChessMove, f32>,
|
policy_target: Vec<(usize, f32)>,
|
||||||
value_target: f32,
|
value_target: f32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
TrainingSample {
|
TrainingSample {
|
||||||
@ -273,6 +272,7 @@ impl TrainingSample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
||||||
|
// move_dist is already a compact Vec<(encoded_move_index, prob)>
|
||||||
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -301,9 +301,8 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
|||||||
.cloned()
|
.cloned()
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
let mut policy = vec![0.0f32; 4672];
|
let mut policy = vec![0.0f32; 4672];
|
||||||
let stm = item.board_state.board.side_to_move();
|
for (idx, prob) in item.policy_target.iter() {
|
||||||
for (mv, prob) in item.policy_target {
|
policy[*idx] = *prob;
|
||||||
policy[encode_move(mv, stm)] = prob;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize
|
// Normalize
|
||||||
|
|||||||
@ -1,15 +1,16 @@
|
|||||||
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
|
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
|
||||||
|
use crate::net::encoding::decode_move;
|
||||||
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
|
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
|
||||||
use burn::data::dataloader::batcher::Batcher;
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
use burn::module::{AutodiffModule, Module};
|
use burn::module::{AutodiffModule, Module};
|
||||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
||||||
use burn::tensor::backend::AutodiffBackend;
|
use burn::tensor::backend::AutodiffBackend;
|
||||||
use chess::ChessMove;
|
use chess::{ChessMove, Color};
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rand::{RngExt, SeedableRng};
|
use rand::{RngExt, SeedableRng};
|
||||||
use std::collections::{HashMap, VecDeque};
|
use std::collections::VecDeque;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -99,7 +100,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
|
|
||||||
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
|
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
|
||||||
|
|
||||||
let mv = sample_move(&adjusted, &mut rng).unwrap();
|
let stm = board_state.board.side_to_move();
|
||||||
|
let mv = sample_move(&adjusted, &mut rng, stm).unwrap();
|
||||||
println!("playing move: {}", mv);
|
println!("playing move: {}", mv);
|
||||||
board_state.apply_move(mv)
|
board_state.apply_move(mv)
|
||||||
}
|
}
|
||||||
@ -187,56 +189,56 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_temperature(
|
fn apply_temperature(visits: &[(usize, f32)], temperature: f32) -> Vec<(usize, f32)> {
|
||||||
visits: &HashMap<ChessMove, f32>,
|
|
||||||
temperature: f32,
|
|
||||||
) -> HashMap<ChessMove, f32> {
|
|
||||||
if visits.is_empty() {
|
if visits.is_empty() {
|
||||||
return HashMap::new();
|
return Vec::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case: deterministic selection
|
// Special case: deterministic selection
|
||||||
if temperature == 0.0 {
|
if temperature == 0.0 {
|
||||||
let (&best_move, _) = visits
|
let (&best_idx, _) = visits
|
||||||
.iter()
|
.iter()
|
||||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
||||||
|
.map(|(i, p)| (i, p))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut out = HashMap::new();
|
return vec![(best_idx, 1.0)];
|
||||||
out.insert(best_move.clone(), 1.0);
|
|
||||||
return out;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let inv_temp = 1.0 / temperature;
|
let inv_temp = 1.0 / temperature;
|
||||||
|
|
||||||
// Step 1: apply exponent
|
// Step 1: apply exponent
|
||||||
let mut adjusted: HashMap<ChessMove, f32> =
|
let mut adjusted: Vec<(usize, f32)> =
|
||||||
visits.iter().map(|(m, v)| (*m, v.powf(inv_temp))).collect();
|
visits.iter().map(|(i, v)| (*i, v.powf(inv_temp))).collect();
|
||||||
|
|
||||||
// Step 2: normalize
|
// Step 2: normalize
|
||||||
let sum: f32 = adjusted.values().sum();
|
let sum: f32 = adjusted.iter().map(|(_, v)| *v).sum();
|
||||||
|
|
||||||
if sum <= 0.0 {
|
if sum <= 0.0 {
|
||||||
return adjusted; // fallback (shouldn't happen in normal MCTS)
|
return adjusted; // fallback (shouldn't happen in normal MCTS)
|
||||||
}
|
}
|
||||||
|
|
||||||
for v in adjusted.values_mut() {
|
for (_, v) in adjusted.iter_mut() {
|
||||||
*v /= sum;
|
*v /= sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
adjusted
|
adjusted
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut SmallRng) -> Option<ChessMove> {
|
fn sample_move(
|
||||||
|
dist: &[(usize, f32)],
|
||||||
|
rng: &mut SmallRng,
|
||||||
|
side_to_move: Color,
|
||||||
|
) -> Option<ChessMove> {
|
||||||
let mut r: f32 = rng.random_range(0.0..1.0);
|
let mut r: f32 = rng.random_range(0.0..1.0);
|
||||||
|
|
||||||
for (m, p) in dist {
|
for (idx, p) in dist {
|
||||||
r -= p;
|
r -= *p;
|
||||||
if r <= 0.0 {
|
if r <= 0.0 {
|
||||||
return Some(m.clone());
|
return Some(decode_move(*idx, side_to_move));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fallback due to floating point drift
|
// fallback due to floating point drift
|
||||||
dist.keys().next().cloned()
|
dist.get(0).map(|(idx, _)| decode_move(*idx, side_to_move))
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user