Compare commits
No commits in common. "ba3b962e8615a20970bc2cd6d446a7672075517f" and "be589de6f45dbd1c42a803514c1dcd33e1153c22" have entirely different histories.
ba3b962e86
...
be589de6f4
@ -1,8 +1,7 @@
|
|||||||
#![recursion_limit = "256"]
|
#![recursion_limit = "256"]
|
||||||
|
|
||||||
use burn::backend::Autodiff;
|
use burn::backend::{Autodiff, Wgpu};
|
||||||
use burn::optim::AdamConfig;
|
use burn::optim::AdamConfig;
|
||||||
use burn_ndarray::NdArray;
|
|
||||||
use engine::mcts::MctsConfig;
|
use engine::mcts::MctsConfig;
|
||||||
use engine::training::train::{train, TrainingConfig};
|
use engine::training::train::{train, TrainingConfig};
|
||||||
// fn main() {
|
// fn main() {
|
||||||
@ -15,12 +14,12 @@ use engine::training::train::{train, TrainingConfig};
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// type MyBackend = Wgpu<f32, i32>;
|
type MyBackend = Wgpu<f32, i32>;
|
||||||
// type MyBackend = Cuda<f32, i32>;
|
// type MyBackend = Cuda<f32, i32>;
|
||||||
type MyBackend = NdArray<f32, i32>;
|
// type MyBackend = NdArray<f32, i32>;
|
||||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||||
// let device = burn::backend::wgpu::WgpuDevice::default();
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
let device = burn::backend::ndarray::NdArrayDevice::default();
|
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
||||||
// let device = burn::backend::cuda::CudaDevice::default();
|
// let device = burn::backend::cuda::CudaDevice::default();
|
||||||
|
|
||||||
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
|
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
|
||||||
|
|||||||
@ -5,10 +5,9 @@ use crate::net::model::ChessModel;
|
|||||||
use burn::prelude::Backend;
|
use burn::prelude::Backend;
|
||||||
use burn::Tensor;
|
use burn::Tensor;
|
||||||
use chess::BoardStatus::{Checkmate, Stalemate};
|
use chess::BoardStatus::{Checkmate, Stalemate};
|
||||||
use chess::Color::{Black, White};
|
use chess::Color::White;
|
||||||
use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
|
use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
|
||||||
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
||||||
use rand::SeedableRng;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
@ -59,13 +58,16 @@ impl Clone for Node {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct MctsResults {
|
pub struct MctsResults {
|
||||||
pub board_state: BoardState,
|
pub board_state: BoardState,
|
||||||
// Compact encoded move distribution: (encoded_move_index, probability)
|
pub move_dist: HashMap<ChessMove, f32>,
|
||||||
pub move_dist: Vec<(usize, f32)>,
|
|
||||||
pub value: f32,
|
pub value: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MctsResults {
|
impl MctsResults {
|
||||||
pub fn new(board_state: BoardState, move_dist: Vec<(usize, f32)>, value: f32) -> MctsResults {
|
pub fn new(
|
||||||
|
board_state: BoardState,
|
||||||
|
move_dist: HashMap<ChessMove, f32>,
|
||||||
|
value: f32,
|
||||||
|
) -> MctsResults {
|
||||||
MctsResults {
|
MctsResults {
|
||||||
board_state,
|
board_state,
|
||||||
move_dist,
|
move_dist,
|
||||||
@ -121,25 +123,12 @@ impl<B: Backend> Mcts<B> {
|
|||||||
let root = 0;
|
let root = 0;
|
||||||
nodes.push(Node::new(0.0, board_state.clone(), None));
|
nodes.push(Node::new(0.0, board_state.clone(), None));
|
||||||
|
|
||||||
// Expand root to create initial children and priors
|
|
||||||
self.expand(root, &mut nodes, model, device);
|
self.expand(root, &mut nodes, model, device);
|
||||||
|
|
||||||
// Apply Dirichlet noise to root children
|
// 👇 APPLY DIRICHLET NOISE HERE
|
||||||
self.add_dirichlet_noise(root, &mut nodes);
|
self.add_dirichlet_noise(root, &mut nodes);
|
||||||
|
|
||||||
// We'll batch leaf evaluations to reduce per-leaf model calls and device-host syncs.
|
for _ in 0..self.config.num_simulations {
|
||||||
let mut sims_done = 0usize;
|
|
||||||
let num_sims = self.config.num_simulations;
|
|
||||||
// Tunable batch size for NN evaluation. Small value is safe; larger values increase throughput on GPU.
|
|
||||||
let batch_max = 32usize;
|
|
||||||
|
|
||||||
while sims_done < num_sims {
|
|
||||||
// Collect a batch of leaf nodes (and their selection paths)
|
|
||||||
let mut leaf_nodes: Vec<usize> = Vec::new();
|
|
||||||
let mut leaf_paths: Vec<Vec<usize>> = Vec::new();
|
|
||||||
let mut leaf_states: Vec<Tensor<B, 4>> = Vec::new();
|
|
||||||
|
|
||||||
while leaf_nodes.len() < std::cmp::min(batch_max, num_sims - sims_done) {
|
|
||||||
let mut path = vec![root];
|
let mut path = vec![root];
|
||||||
let mut current = root;
|
let mut current = root;
|
||||||
|
|
||||||
@ -148,94 +137,18 @@ impl<B: Backend> Mcts<B> {
|
|||||||
path.push(current);
|
path.push(current);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record leaf node and its path
|
let value: f32 = self.expand(current, &mut nodes, model, device);
|
||||||
leaf_nodes.push(current);
|
|
||||||
leaf_paths.push(path.clone());
|
|
||||||
|
|
||||||
// Prepare state tensor for this leaf
|
let color = nodes[current].board_state.board.side_to_move();
|
||||||
let state: Tensor<B, 4> = encode_board_state_perspective(&nodes[current].board_state, device)
|
self.backpropagate(&mut nodes, &path, value, color);
|
||||||
.reshape([1, 18, 8, 8]);
|
|
||||||
leaf_states.push(state);
|
|
||||||
|
|
||||||
sims_done += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if leaf_nodes.is_empty() {
|
let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)>
|
||||||
break;
|
for idx in nodes[root].children.iter() {
|
||||||
}
|
move_dist.insert(
|
||||||
|
nodes[*idx].last_move.expect("move didnt exist"),
|
||||||
// Batch evaluate the collected leaf states
|
nodes[*idx].visit_count as f32 / self.config.num_simulations as f32,
|
||||||
let batch = Tensor::cat(leaf_states, 0);
|
);
|
||||||
let (policy_batch, value_batch) = model.forward(batch);
|
|
||||||
|
|
||||||
// Move tensors to host once per batch
|
|
||||||
let policy_data = policy_batch.into_data().to_vec::<f32>().unwrap();
|
|
||||||
let value_data = value_batch.into_data().to_vec::<f32>().unwrap();
|
|
||||||
|
|
||||||
let num_moves = policy_data.len() / leaf_nodes.len();
|
|
||||||
|
|
||||||
// Process each evaluated leaf: expand and backpropagate
|
|
||||||
for (i, &node_idx) in leaf_nodes.iter().enumerate() {
|
|
||||||
let path = &leaf_paths[i];
|
|
||||||
|
|
||||||
// slice for this sample's logits
|
|
||||||
let start = i * num_moves;
|
|
||||||
let end = start + num_moves;
|
|
||||||
let logits = &policy_data[start..end];
|
|
||||||
|
|
||||||
// Convert logits to probabilities with a numerically-stable softmax on host
|
|
||||||
let mut max_logit = std::f32::NEG_INFINITY;
|
|
||||||
for &v in logits.iter() {
|
|
||||||
if v > max_logit {
|
|
||||||
max_logit = v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut exps_sum = 0.0f32;
|
|
||||||
// We'll build a Vec<f32> of probabilities lazily when needed
|
|
||||||
let mut probs: Vec<f32> = Vec::new();
|
|
||||||
probs.resize(num_moves, 0.0);
|
|
||||||
for (j, &v) in logits.iter().enumerate() {
|
|
||||||
let e = (v - max_logit).exp();
|
|
||||||
probs[j] = e;
|
|
||||||
exps_sum += e;
|
|
||||||
}
|
|
||||||
if exps_sum > 0.0 {
|
|
||||||
for p in probs.iter_mut() {
|
|
||||||
*p /= exps_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand: add legal moves as children with prior from probs
|
|
||||||
let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&nodes[node_idx].board_state.board).collect();
|
|
||||||
for mv in legal_moves {
|
|
||||||
let stm = nodes[node_idx].board_state.board.side_to_move();
|
|
||||||
let idx = encode_move(mv, stm);
|
|
||||||
let prior = probs[idx];
|
|
||||||
|
|
||||||
let mut new_board = nodes[node_idx].board_state.clone();
|
|
||||||
new_board.apply_move(mv);
|
|
||||||
|
|
||||||
let child_idx = nodes.len();
|
|
||||||
nodes.push(Node::new(prior, new_board, Some(mv)));
|
|
||||||
nodes[node_idx].children.push(child_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backpropagate the value for this leaf
|
|
||||||
let value = value_data[i];
|
|
||||||
let color = nodes[node_idx].board_state.board.side_to_move();
|
|
||||||
self.backpropagate(&mut nodes, path, value, color);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build compact move distribution: encoded move index -> probability (visits / num_simulations)
|
|
||||||
let mut move_dist: Vec<(usize, f32)> = Vec::with_capacity(nodes[root].children.len());
|
|
||||||
let stm = board_state.board.side_to_move();
|
|
||||||
let denom = self.config.num_simulations as f32;
|
|
||||||
for child_idx in nodes[root].children.iter() {
|
|
||||||
let mv = nodes[*child_idx].last_move.expect("move didnt exist");
|
|
||||||
let enc = encode_move(mv, stm);
|
|
||||||
let prob = nodes[*child_idx].visit_count as f32 / denom;
|
|
||||||
move_dist.push((enc, prob));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
|
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
|
||||||
@ -248,24 +161,6 @@ impl<B: Backend> Mcts<B> {
|
|||||||
model: &ChessModel<B>,
|
model: &ChessModel<B>,
|
||||||
device: &B::Device,
|
device: &B::Device,
|
||||||
) -> f32 {
|
) -> f32 {
|
||||||
if arena[node_idx].board_state.status == BoardStateStatus::Stalemate
|
|
||||||
|| arena[node_idx].board_state.status == BoardStateStatus::Threefold
|
|
||||||
|| arena[node_idx].board_state.status == BoardStateStatus::FiftyMove
|
|
||||||
{
|
|
||||||
0.0
|
|
||||||
} else if arena[node_idx].board_state.status == BoardStateStatus::WhiteWinner {
|
|
||||||
if arena[node_idx].board_state.board.side_to_move() == Black {
|
|
||||||
1.0
|
|
||||||
} else {
|
|
||||||
-1.0
|
|
||||||
}
|
|
||||||
} else if arena[node_idx].board_state.status == BoardStateStatus::BlackWinner {
|
|
||||||
if arena[node_idx].board_state.board.side_to_move() == White {
|
|
||||||
1.0
|
|
||||||
} else {
|
|
||||||
-1.0
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let state: Tensor<B, 4> =
|
let state: Tensor<B, 4> =
|
||||||
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
||||||
.reshape([1, 18, 8, 8]);
|
.reshape([1, 18, 8, 8]);
|
||||||
@ -294,7 +189,6 @@ impl<B: Backend> Mcts<B> {
|
|||||||
|
|
||||||
value_head.into_data().to_vec().unwrap()[0]
|
value_head.into_data().to_vec().unwrap()[0]
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) {
|
fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) {
|
||||||
for &idx in path {
|
for &idx in path {
|
||||||
@ -335,14 +229,9 @@ fn dirichlet_sample(size: usize, alpha: f32) -> Vec<f32> {
|
|||||||
|
|
||||||
let gamma = Gamma::new(alpha as f64, 1.0).unwrap();
|
let gamma = Gamma::new(alpha as f64, 1.0).unwrap();
|
||||||
|
|
||||||
// Use a single SmallRng seeded from system time (avoid depending on thread_rng helper)
|
let mut samples: Vec<f32> = (0..size)
|
||||||
let now = std::time::SystemTime::now()
|
.map(|_| gamma.sample(&mut rand::rng()) as f32)
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.collect();
|
||||||
.unwrap();
|
|
||||||
let seed = now.as_nanos() as u64;
|
|
||||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
|
|
||||||
|
|
||||||
let mut samples: Vec<f32> = (0..size).map(|_| gamma.sample(&mut rng) as f32).collect();
|
|
||||||
|
|
||||||
let sum: f32 = samples.iter().sum();
|
let sum: f32 = samples.iter().sum();
|
||||||
|
|
||||||
@ -445,6 +334,8 @@ pub fn heuristic_eval(board: &Board, perspective: Color) -> f32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
value
|
value
|
||||||
|
|
||||||
|
// board.checkers()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use crate::mcts::{BoardState, MctsResults};
|
use crate::mcts::{BoardState, MctsResults};
|
||||||
use crate::net::encoding::encode_board_state_perspective;
|
use crate::net::encoding::{encode_board_state_perspective, encode_move};
|
||||||
use burn::data::dataloader::batcher::Batcher;
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
use burn::nn::conv::Conv2dConfig;
|
use burn::nn::conv::Conv2dConfig;
|
||||||
use burn::nn::loss::{MseLoss, Reduction};
|
use burn::nn::loss::{MseLoss, Reduction};
|
||||||
@ -13,6 +13,8 @@ use burn::{
|
|||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use burn_ndarray::NdArray;
|
use burn_ndarray::NdArray;
|
||||||
|
use chess::ChessMove;
|
||||||
|
use std::collections::HashMap;
|
||||||
/*
|
/*
|
||||||
Input planes:
|
Input planes:
|
||||||
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
||||||
@ -253,15 +255,14 @@ impl<B: Backend> InferenceStep for ChessModel<B> {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct TrainingSample {
|
pub struct TrainingSample {
|
||||||
pub board_state: BoardState,
|
pub board_state: BoardState,
|
||||||
// Compact representation: list of (encoded_move_index, probability)
|
pub policy_target: HashMap<ChessMove, f32>,
|
||||||
pub policy_target: Vec<(usize, f32)>,
|
|
||||||
pub value_target: f32,
|
pub value_target: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrainingSample {
|
impl TrainingSample {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
policy_target: Vec<(usize, f32)>,
|
policy_target: HashMap<ChessMove, f32>,
|
||||||
value_target: f32,
|
value_target: f32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
TrainingSample {
|
TrainingSample {
|
||||||
@ -272,7 +273,6 @@ impl TrainingSample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
||||||
// move_dist is already a compact Vec<(encoded_move_index, prob)>
|
|
||||||
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -301,8 +301,9 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
|||||||
.cloned()
|
.cloned()
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
let mut policy = vec![0.0f32; 4672];
|
let mut policy = vec![0.0f32; 4672];
|
||||||
for (idx, prob) in item.policy_target.iter() {
|
let stm = item.board_state.board.side_to_move();
|
||||||
policy[*idx] = *prob;
|
for (mv, prob) in item.policy_target {
|
||||||
|
policy[encode_move(mv, stm)] = prob;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize
|
// Normalize
|
||||||
|
|||||||
@ -1,21 +1,19 @@
|
|||||||
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
|
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
|
||||||
use crate::net::encoding::decode_move;
|
|
||||||
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
|
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
|
||||||
use burn::data::dataloader::batcher::Batcher;
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
use burn::module::{AutodiffModule, Module};
|
use burn::module::{AutodiffModule, Module};
|
||||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
||||||
use burn::tensor::backend::AutodiffBackend;
|
use burn::tensor::backend::AutodiffBackend;
|
||||||
use chess::{ChessMove, Color};
|
use chess::ChessMove;
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::ThreadRng;
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rand::{RngExt, SeedableRng};
|
use rand::RngExt;
|
||||||
use std::collections::VecDeque;
|
use std::collections::{HashMap, VecDeque};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
pub struct TrainingConfig {
|
pub struct TrainingConfig {
|
||||||
pub max_time_s: Option<u64>,
|
pub max_time_s: Option<u64>,
|
||||||
@ -37,8 +35,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
let model_path = format!("artifacts/{}", training_config.model_name.as_str());
|
let model_path = format!("artifacts/{}", training_config.model_name.as_str());
|
||||||
println!("Creating model...");
|
println!("Creating model...");
|
||||||
let mut model: ChessModel<B> = ChessModelConfig::init(
|
let mut model: ChessModel<B> = ChessModelConfig::init(
|
||||||
training_config.num_blocks,
|
|
||||||
training_config.hidden_channels,
|
training_config.hidden_channels,
|
||||||
|
training_config.num_blocks,
|
||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
if training_config.load_model {
|
if training_config.load_model {
|
||||||
@ -69,14 +67,7 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create RNG once and reuse it for sampling and shuffling
|
let mut rng = rand::rng();
|
||||||
// Seed from system time (platform default entropy may be unavailable in some contexts)
|
|
||||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
|
||||||
let seed = now.as_nanos() as u64;
|
|
||||||
let mut rng = SmallRng::seed_from_u64(seed);
|
|
||||||
|
|
||||||
// Initialize optimizer once so state (moments) persist across steps
|
|
||||||
let mut optim = training_config.optimizer.init();
|
|
||||||
|
|
||||||
println!("Starting training...");
|
println!("Starting training...");
|
||||||
while train.load(Ordering::Relaxed) {
|
while train.load(Ordering::Relaxed) {
|
||||||
@ -100,8 +91,7 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
|
|
||||||
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
|
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
|
||||||
|
|
||||||
let stm = board_state.board.side_to_move();
|
let mv = sample_move(&adjusted, &mut rng).unwrap();
|
||||||
let mv = sample_move(&adjusted, &mut rng, stm).unwrap();
|
|
||||||
println!("playing move: {}", mv);
|
println!("playing move: {}", mv);
|
||||||
board_state.apply_move(mv)
|
board_state.apply_move(mv)
|
||||||
}
|
}
|
||||||
@ -150,6 +140,8 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
|
|
||||||
let batch = batcher.batch(samples, &device);
|
let batch = batcher.batch(samples, &device);
|
||||||
|
|
||||||
|
let mut optim = training_config.optimizer.init();
|
||||||
|
|
||||||
let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
|
let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
|
||||||
|
|
||||||
let grads = output.loss.backward();
|
let grads = output.loss.backward();
|
||||||
@ -189,56 +181,56 @@ pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Dev
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_temperature(visits: &[(usize, f32)], temperature: f32) -> Vec<(usize, f32)> {
|
fn apply_temperature(
|
||||||
|
visits: &HashMap<ChessMove, f32>,
|
||||||
|
temperature: f32,
|
||||||
|
) -> HashMap<ChessMove, f32> {
|
||||||
if visits.is_empty() {
|
if visits.is_empty() {
|
||||||
return Vec::new();
|
return HashMap::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case: deterministic selection
|
// Special case: deterministic selection
|
||||||
if temperature == 0.0 {
|
if temperature == 0.0 {
|
||||||
let (&best_idx, _) = visits
|
let (&best_move, _) = visits
|
||||||
.iter()
|
.iter()
|
||||||
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||||
.map(|(i, p)| (i, p))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
return vec![(best_idx, 1.0)];
|
let mut out = HashMap::new();
|
||||||
|
out.insert(best_move.clone(), 1.0);
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
let inv_temp = 1.0 / temperature;
|
let inv_temp = 1.0 / temperature;
|
||||||
|
|
||||||
// Step 1: apply exponent
|
// Step 1: apply exponent
|
||||||
let mut adjusted: Vec<(usize, f32)> =
|
let mut adjusted: HashMap<ChessMove, f32> =
|
||||||
visits.iter().map(|(i, v)| (*i, v.powf(inv_temp))).collect();
|
visits.iter().map(|(m, v)| (*m, v.powf(inv_temp))).collect();
|
||||||
|
|
||||||
// Step 2: normalize
|
// Step 2: normalize
|
||||||
let sum: f32 = adjusted.iter().map(|(_, v)| *v).sum();
|
let sum: f32 = adjusted.values().sum();
|
||||||
|
|
||||||
if sum <= 0.0 {
|
if sum <= 0.0 {
|
||||||
return adjusted; // fallback (shouldn't happen in normal MCTS)
|
return adjusted; // fallback (shouldn't happen in normal MCTS)
|
||||||
}
|
}
|
||||||
|
|
||||||
for (_, v) in adjusted.iter_mut() {
|
for v in adjusted.values_mut() {
|
||||||
*v /= sum;
|
*v /= sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
adjusted
|
adjusted
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_move(
|
fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut ThreadRng) -> Option<ChessMove> {
|
||||||
dist: &[(usize, f32)],
|
|
||||||
rng: &mut SmallRng,
|
|
||||||
side_to_move: Color,
|
|
||||||
) -> Option<ChessMove> {
|
|
||||||
let mut r: f32 = rng.random_range(0.0..1.0);
|
let mut r: f32 = rng.random_range(0.0..1.0);
|
||||||
|
|
||||||
for (idx, p) in dist {
|
for (m, p) in dist {
|
||||||
r -= *p;
|
r -= p;
|
||||||
if r <= 0.0 {
|
if r <= 0.0 {
|
||||||
return Some(decode_move(*idx, side_to_move));
|
return Some(m.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fallback due to floating point drift
|
// fallback due to floating point drift
|
||||||
dist.get(0).map(|(idx, _)| decode_move(*idx, side_to_move))
|
dist.keys().next().cloned()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user