use crate::mcts::{BoardState, MctsResults}; use crate::net::encoding::{encode_board_state_perspective, encode_move}; use burn::data::dataloader::batcher::Batcher; use burn::nn::conv::Conv2dConfig; use burn::nn::loss::{MseLoss, Reduction}; use burn::nn::{BatchNorm, BatchNormConfig, LinearConfig, PaddingConfig2d}; use burn::tensor::activation::log_softmax; use burn::tensor::backend::AutodiffBackend; use burn::tensor::Transaction; use burn::train::{InferenceStep, ItemLazy, TrainOutput, TrainStep}; use burn::{ nn::{conv::Conv2d, Linear, Relu}, prelude::*, }; use burn_ndarray::NdArray; use chess::ChessMove; use std::collections::HashMap; /* Input planes: 1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King) 7-12: opponent pieces 13: your kingside 14: your queenside 15: opponent kingside 16: opponent queenside 17: rep >= 2 18: En passant square (gap pass over) Output: 4672 moves represented through planes 1-56: sliding moves 57-64: Knight moves 65-73: underpromotions */ #[derive(Module, Debug)] pub struct ResidualBlock { conv1: Conv2d, bn1: BatchNorm, conv2: Conv2d, bn2: BatchNorm, activation: Relu, } #[derive(Config, Debug)] pub struct ResidualBlockConfig { channels: usize, } impl ResidualBlock { pub fn new(channels: usize, device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([channels, channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), bn1: BatchNormConfig::new(channels).init(device), conv2: Conv2dConfig::new([channels, channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), bn2: BatchNormConfig::new(channels).init(device), activation: Relu::new(), } } pub fn forward(&self, x: Tensor) -> Tensor { let residual = x.clone(); let x = self.conv1.forward(x); let x = self.bn1.forward(x); let x = self.activation.forward(x); let x = self.conv2.forward(x); let x = self.bn2.forward(x); self.activation.forward(x + residual) } } #[derive(Module, Debug)] pub struct ChessModel { // Initial convolution conv: Conv2d, bn: BatchNorm, // Residual tower residual_blocks: Vec>, // Policy head policy_conv: Conv2d, policy_bn: BatchNorm, policy_fc: Linear, // Value head value_conv: Conv2d, value_bn: BatchNorm, value_fc1: Linear, value_fc2: Linear, activation: Relu, } #[derive(Config, Debug)] pub struct ChessModelConfig { num_blocks: usize, channels: usize, } impl ChessModelConfig { pub fn init( num_blocks: usize, channels: usize, device: &B::Device, ) -> ChessModel { ChessModel { conv: Conv2dConfig::new([18, channels], [3, 3]) // 18 plane input .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), bn: BatchNormConfig::new(channels).init(device), residual_blocks: (0..num_blocks) .map(|_| ResidualBlock::new(channels, device)) .collect(), // Policy head policy_conv: Conv2dConfig::new([channels, 2], [1, 1]).init(device), policy_bn: BatchNormConfig::new(2).init(device), policy_fc: LinearConfig::new(2 * 8 * 8, 8 * 8 * 73).init(device), // 4672 typical chess move space // Value head value_conv: Conv2dConfig::new([channels, 1], [1, 1]).init(device), value_bn: BatchNormConfig::new(1).init(device), value_fc1: LinearConfig::new(1 * 8 * 8, 256).init(device), value_fc2: LinearConfig::new(256, 1).init(device), activation: Relu::new(), } } } impl ChessModel { pub fn forward(&self, x: Tensor) -> (Tensor, Tensor) { let mut x = self.conv.forward(x); x = self.bn.forward(x); x = self.activation.forward(x); for block in &self.residual_blocks { x = block.forward(x); } let batch_size = x.dims()[0]; // -------- Policy Head -------- let mut p = self.policy_conv.forward(x.clone()); p = self.policy_bn.forward(p); p = self.activation.forward(p); let mut p = p.reshape([batch_size, 2 * 8 * 8]); p = self.policy_fc.forward(p); // -------- Value Head -------- let mut v = self.value_conv.forward(x); v = self.value_bn.forward(v); v = self.activation.forward(v); let mut v = v.reshape([batch_size, 8 * 8]); v = self.activation.forward(self.value_fc1.forward(v)); v = self.value_fc2.forward(v).tanh(); (p, v) } pub fn forward_chess( &self, boards: Tensor, // e.g. [batch, channels, 8, 8] policy_targets: Tensor, // move distribution (make sure is normalized) value_targets: Tensor, // scalar evaluation ) -> ChessOutput { let (policy_logits, value) = self.forward(boards); let log_probs = log_softmax(policy_logits.clone(), 1); let policy_loss = policy_targets .clone() .mul(log_probs) .neg() .sum_dim(1) .mean(); let value_loss = MseLoss::new().forward(value.clone(), value_targets.clone(), Reduction::Mean); let total_loss = policy_loss + 0.5 * value_loss; ChessOutput { policy_logits, value, policy_targets, value_targets, loss: total_loss, } } } pub struct ChessOutput { pub policy_logits: Tensor, // [sample, num_moves (4672)] pub value: Tensor, // [sample, 1 value] pub policy_targets: Tensor, pub value_targets: Tensor, // kept 2d for consistency? pub loss: Tensor, // [sample] } impl ItemLazy for ChessOutput { type ItemSync = ChessOutput; fn sync(self) -> Self::ItemSync { let [policy_logits, value, policy_targets, value_targets, loss] = Transaction::default() .register(self.policy_logits) .register(self.value) .register(self.policy_targets) .register(self.value_targets) .register(self.loss) .execute() .try_into() .expect("Correct amount of tensor data"); let device = &Default::default(); ChessOutput { policy_logits: Tensor::from_data(policy_logits, device), value: Tensor::from_data(value, device), policy_targets: Tensor::from_data(policy_targets, device), value_targets: Tensor::from_data(value_targets, device), loss: Tensor::from_data(loss, device), } } } impl TrainStep for ChessModel { type Input = ChessBatch; type Output = ChessOutput; fn step(&self, batch: ChessBatch) -> TrainOutput> { let item = self.forward_chess(batch.states, batch.policy_targets, batch.value_targets); TrainOutput::new(self, item.loss.backward(), item) } } impl InferenceStep for ChessModel { type Input = ChessBatch; type Output = ChessOutput; fn step(&self, batch: ChessBatch) -> ChessOutput { self.forward_chess(batch.states, batch.policy_targets, batch.value_targets) } } #[derive(Clone)] pub struct TrainingSample { pub board_state: BoardState, pub policy_target: HashMap, pub value_target: f32, } impl TrainingSample { pub fn new( board_state: BoardState, policy_target: HashMap, value_target: f32, ) -> Self { TrainingSample { board_state, policy_target, value_target, } } pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self { TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome) } } #[derive(Clone, Default)] pub struct ChessBatcher {} #[derive(Clone, Debug)] pub struct ChessBatch { pub states: Tensor, pub policy_targets: Tensor, pub value_targets: Tensor, } impl Batcher> for ChessBatcher { fn batch(&self, items: Vec, device: &B::Device) -> ChessBatch { let state_tensors = items .iter() .map(|item| { encode_board_state_perspective(&item.board_state, device).reshape([1, 18, 8, 8]) }) .collect::>(); let policy_target_tensors = items .iter() .cloned() .map(|item| { let mut policy = vec![0.0f32; 4672]; let stm = item.board_state.board.side_to_move(); for (mv, prob) in item.policy_target { policy[encode_move(mv, stm)] = prob; } // Normalize let sum: f32 = policy.iter().sum(); if sum > 0.0 { for p in &mut policy { *p /= sum; } } Tensor::::from_floats(TensorData::new(policy, [4672]), device).unsqueeze() }) .collect::>(); let value_target_tensors = items .iter() .map(|item| { Tensor::::from_floats(TensorData::new(vec![item.value_target], [1]), device) .reshape([1, 1]) }) .collect::>(); let states = Tensor::cat(state_tensors, 0); // [B, 18, 8, 8] let policy_targets = Tensor::cat(policy_target_tensors, 0); // [B, 4672] let value_targets = Tensor::cat(value_target_tensors, 0); // [B, 1] ChessBatch { states, policy_targets, value_targets, } } }