chess_dragon/engine/src/net/model.rs
2026-05-23 15:06:10 -05:00

339 lines
10 KiB
Rust

use crate::mcts::{BoardState, MctsResults};
use crate::net::encoding::{encode_board_state_perspective, encode_move};
use burn::data::dataloader::batcher::Batcher;
use burn::nn::conv::Conv2dConfig;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{BatchNorm, BatchNormConfig, LinearConfig, PaddingConfig2d};
use burn::tensor::activation::log_softmax;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Transaction;
use burn::train::{InferenceStep, ItemLazy, TrainOutput, TrainStep};
use burn::{
nn::{conv::Conv2d, Linear, Relu},
prelude::*,
};
use burn_ndarray::NdArray;
use chess::ChessMove;
use std::collections::HashMap;
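/*
Input planes (18 total, encoded from the side-to-move's perspective):
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
7-12: opponent pieces (same order)
13: your kingside castling right
14: your queenside castling right
15: opponent kingside castling right
16: opponent queenside castling right
17: repetition count >= 2
18: en passant target square (the square the double-pushed pawn passed over)
Output:
4672 move logits = 73 move planes x 64 from-squares (8 * 8 * 73)
1-56: sliding moves
57-64: knight moves
65-73: underpromotions
*/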
/// Residual block: two 3x3 convolution + batch-norm stages with an identity skip connection.
///
/// [`ResidualBlock::forward`] computes `relu(bn2(conv2(relu(bn1(conv1(x))))) + x)`; both
/// convolutions keep the channel count unchanged so the skip can be added element-wise.
#[derive(Module, Debug)]
pub struct ResidualBlock<B: Backend> {
    // First 3x3 convolution (channels -> channels, explicit padding 1).
    conv1: Conv2d<B>,
    bn1: BatchNorm<B>,
    // Second 3x3 convolution (channels -> channels, explicit padding 1).
    conv2: Conv2d<B>,
    bn2: BatchNorm<B>,
    // ReLU applied after bn1 and again after the residual addition.
    activation: Relu,
}
/// Configuration for a [`ResidualBlock`].
///
/// NOTE(review): `ResidualBlock::new` in this file takes `channels` directly and does not
/// read this config; confirm whether it is used elsewhere before relying on it.
#[derive(Config, Debug)]
pub struct ResidualBlockConfig {
    // Input and output channel count of both convolutions.
    channels: usize,
}
impl<B: Backend> ResidualBlock<B> {
    /// Creates a block whose convolutions map `channels` to `channels`.
    ///
    /// Both convolutions are 3x3 with an explicit padding of 1 on every side, so the spatial
    /// size is preserved. Sub-layers are initialised in field order: conv1, bn1, conv2, bn2.
    pub fn new(channels: usize, device: &B::Device) -> Self {
        let conv3x3 = Conv2dConfig::new([channels, channels], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1));
        let norm = BatchNormConfig::new(channels);
        Self {
            conv1: conv3x3.init(device),
            bn1: norm.init(device),
            conv2: conv3x3.init(device),
            bn2: norm.init(device),
            activation: Relu::new(),
        }
    }

    /// Applies `relu(bn2(conv2(relu(bn1(conv1(x))))) + x)`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let skip = x.clone();
        let hidden = self
            .activation
            .forward(self.bn1.forward(self.conv1.forward(x)));
        let branch = self.bn2.forward(self.conv2.forward(hidden));
        self.activation.forward(branch + skip)
    }
}
/// Policy/value network: a convolutional stem, a tower of [`ResidualBlock`]s, and two heads.
///
/// Layer sizes as built by `ChessModelConfig::init`:
/// - stem: 3x3 conv from 18 input planes to `channels`, batch-norm, ReLU;
/// - policy head: 1x1 conv to 2 planes, batch-norm, ReLU, flatten (2 * 8 * 8 = 128),
///   linear to 8 * 8 * 73 = 4672 move logits;
/// - value head: 1x1 conv to 1 plane, batch-norm, ReLU, flatten (64), linear to 256 + ReLU,
///   linear to 1, tanh.
#[derive(Module, Debug)]
pub struct ChessModel<B: Backend> {
    // Stem: 3x3 convolution over the encoded input planes, then batch-norm.
    conv: Conv2d<B>,
    bn: BatchNorm<B>,
    // Residual tower, applied in order by `forward`.
    residual_blocks: Vec<ResidualBlock<B>>,
    // Policy head: 1x1 conv -> batch-norm -> ReLU -> flatten -> linear (logits).
    policy_conv: Conv2d<B>,
    policy_bn: BatchNorm<B>,
    policy_fc: Linear<B>,
    // Value head: 1x1 conv -> batch-norm -> ReLU -> flatten -> fc1 + ReLU -> fc2 -> tanh.
    value_conv: Conv2d<B>,
    value_bn: BatchNorm<B>,
    value_fc1: Linear<B>,
    value_fc2: Linear<B>,
    // Shared ReLU used by the stem and both heads.
    activation: Relu,
}
/// Architecture hyperparameters for [`ChessModel`].
///
/// NOTE(review): `ChessModelConfig::init` takes `num_blocks`/`channels` as explicit
/// arguments rather than reading these fields; presumably callers pass the same values.
#[derive(Config, Debug)]
pub struct ChessModelConfig {
    // Number of residual blocks in the tower.
    num_blocks: usize,
    // Channel width of the stem and residual tower.
    channels: usize,
}
impl ChessModelConfig {
    /// Feature planes per encoded position (see the plane-layout comment above the model).
    const INPUT_PLANES: usize = 18;
    /// Squares on the board; every spatial feature map in the model is 8x8.
    const BOARD_SQUARES: usize = 8 * 8;
    /// Move planes per from-square; 64 * 73 = 4672 policy logits.
    const POLICY_PLANES: usize = 73;
    /// Filters in the policy head's 1x1 convolution.
    const POLICY_FILTERS: usize = 2;
    /// Filters in the value head's 1x1 convolution.
    const VALUE_FILTERS: usize = 1;
    /// Hidden width of the value head's first fully connected layer.
    const VALUE_HIDDEN: usize = 256;

    /// Builds a freshly initialised [`ChessModel`] on `device`.
    ///
    /// * `num_blocks` - number of residual blocks in the tower.
    /// * `channels` - channel width of the stem and of every residual block.
    ///
    /// Note that this associated function does not read the config's own fields; use
    /// [`ChessModelConfig::init_model`] to build from a config value.
    pub fn init<B: Backend>(
        num_blocks: usize,
        channels: usize,
        device: &B::Device,
    ) -> ChessModel<B> {
        // Struct-literal fields are initialised in source order; keep it stable so the
        // sequence of parameter initialisations does not change.
        ChessModel {
            // Stem: 3x3 conv with explicit padding 1 keeps the 8x8 board size.
            conv: Conv2dConfig::new([Self::INPUT_PLANES, channels], [3, 3])
                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
                .init(device),
            bn: BatchNormConfig::new(channels).init(device),
            residual_blocks: (0..num_blocks)
                .map(|_| ResidualBlock::new(channels, device))
                .collect(),
            // Policy head: 1x1 conv -> flatten -> one logit per encoded move (4672).
            policy_conv: Conv2dConfig::new([channels, Self::POLICY_FILTERS], [1, 1]).init(device),
            policy_bn: BatchNormConfig::new(Self::POLICY_FILTERS).init(device),
            policy_fc: LinearConfig::new(
                Self::POLICY_FILTERS * Self::BOARD_SQUARES,
                Self::BOARD_SQUARES * Self::POLICY_PLANES,
            )
            .init(device),
            // Value head: 1x1 conv -> flatten -> hidden layer -> single scalar.
            value_conv: Conv2dConfig::new([channels, Self::VALUE_FILTERS], [1, 1]).init(device),
            value_bn: BatchNormConfig::new(Self::VALUE_FILTERS).init(device),
            value_fc1: LinearConfig::new(
                Self::VALUE_FILTERS * Self::BOARD_SQUARES,
                Self::VALUE_HIDDEN,
            )
            .init(device),
            value_fc2: LinearConfig::new(Self::VALUE_HIDDEN, 1).init(device),
            activation: Relu::new(),
        }
    }

    /// Builds a model sized by this config's own `num_blocks` and `channels` fields.
    ///
    /// Equivalent to `ChessModelConfig::init(self.num_blocks, self.channels, device)`; lets a
    /// config value (e.g. one loaded via burn's `Config` support) drive model construction.
    pub fn init_model<B: Backend>(&self, device: &B::Device) -> ChessModel<B> {
        Self::init(self.num_blocks, self.channels, device)
    }
}
impl<B: Backend> ChessModel<B> {
    /// Runs the network on a batch of encoded positions.
    ///
    /// `x` is the encoded input; with the layers built by `ChessModelConfig::init` it is
    /// expected to be `[batch, 18, 8, 8]` (NOTE(review): confirm against callers).
    ///
    /// Returns `(policy_logits, value)`:
    /// - `policy_logits`: `[batch, 4672]` raw move scores (no softmax applied here);
    /// - `value`: `[batch, 1]`, squashed by `tanh`.
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        // Stem: conv -> batch-norm -> ReLU.
        let stem = self
            .activation
            .forward(self.bn.forward(self.conv.forward(x)));

        // Residual tower, applied in order.
        let trunk = self
            .residual_blocks
            .iter()
            .fold(stem, |features, block| block.forward(features));
        let [batch_size, _, _, _] = trunk.dims();

        // Policy head: 1x1 conv (2 filters) -> BN -> ReLU -> flatten to 128 -> linear.
        let policy_features = self
            .activation
            .forward(self.policy_bn.forward(self.policy_conv.forward(trunk.clone())))
            .reshape([batch_size, 2 * 8 * 8]);
        let policy_logits = self.policy_fc.forward(policy_features);

        // Value head: 1x1 conv (1 filter) -> BN -> ReLU -> flatten to 64 -> FC + ReLU -> FC -> tanh.
        let value_features = self
            .activation
            .forward(self.value_bn.forward(self.value_conv.forward(trunk)))
            .reshape([batch_size, 8 * 8]);
        let hidden = self.activation.forward(self.value_fc1.forward(value_features));
        let value = self.value_fc2.forward(hidden).tanh();

        (policy_logits, value)
    }

    /// Forward pass plus combined training loss for one batch.
    ///
    /// * `boards` - encoded positions, e.g. `[batch, channels, 8, 8]`.
    /// * `policy_targets` - target move distribution per row; the cross-entropy is only a
    ///   proper one when each row sums to 1 (the batcher in this file normalises rows).
    /// * `value_targets` - target scalar evaluation per row, `[batch, 1]`.
    ///
    /// Loss = batch mean of `-sum(policy_targets * log_softmax(logits))`
    ///      + 0.5 * mean-squared error between predicted and target values.
    pub fn forward_chess(
        &self,
        boards: Tensor<B, 4>,
        policy_targets: Tensor<B, 2>,
        value_targets: Tensor<B, 2>,
    ) -> ChessOutput<B> {
        let (policy_logits, value) = self.forward(boards);

        // Soft-target cross-entropy, summed over moves then averaged over the batch.
        let log_probs = log_softmax(policy_logits.clone(), 1);
        let per_sample_ce = policy_targets
            .clone()
            .mul(log_probs)
            .neg()
            .sum_dim(1);
        let policy_loss = per_sample_ce.mean();

        let mse = MseLoss::new();
        let value_loss = mse.forward(value.clone(), value_targets.clone(), Reduction::Mean);

        ChessOutput {
            loss: policy_loss + 0.5 * value_loss,
            policy_logits,
            value,
            policy_targets,
            value_targets,
        }
    }
}
/// Predictions, targets and loss for one batch, as produced by `ChessModel::forward_chess`.
pub struct ChessOutput<B: Backend> {
    pub policy_logits: Tensor<B, 2>, // [batch, 4672] raw (pre-softmax) move logits
    pub value: Tensor<B, 2>,         // [batch, 1] tanh-bounded value prediction
    pub policy_targets: Tensor<B, 2>, // [batch, 4672] target move distribution
    pub value_targets: Tensor<B, 2>, // [batch, 1]; same shape as `value` so MseLoss compares like with like
    pub loss: Tensor<B, 1>,          // single element: batch-mean policy CE + 0.5 * value MSE
}
impl<B: Backend> ItemLazy for ChessOutput<B> {
type ItemSync = ChessOutput<NdArray>;
fn sync(self) -> Self::ItemSync {
let [policy_logits, value, policy_targets, value_targets, loss] = Transaction::default()
.register(self.policy_logits)
.register(self.value)
.register(self.policy_targets)
.register(self.value_targets)
.register(self.loss)
.execute()
.try_into()
.expect("Correct amount of tensor data");
let device = &Default::default();
ChessOutput {
policy_logits: Tensor::from_data(policy_logits, device),
value: Tensor::from_data(value, device),
policy_targets: Tensor::from_data(policy_targets, device),
value_targets: Tensor::from_data(value_targets, device),
loss: Tensor::from_data(loss, device),
}
}
}
impl<B: AutodiffBackend> TrainStep for ChessModel<B> {
type Input = ChessBatch<B>;
type Output = ChessOutput<B>;
fn step(&self, batch: ChessBatch<B>) -> TrainOutput<ChessOutput<B>> {
let item = self.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> InferenceStep for ChessModel<B> {
type Input = ChessBatch<B>;
type Output = ChessOutput<B>;
fn step(&self, batch: ChessBatch<B>) -> ChessOutput<B> {
self.forward_chess(batch.states, batch.policy_targets, batch.value_targets)
}
}
/// One training example: a position plus its policy and value targets.
#[derive(Clone)]
pub struct TrainingSample {
    // Position the targets refer to.
    pub board_state: BoardState,
    // Target move weights keyed by move (e.g. an MCTS `move_dist`). The batcher renormalises
    // them, so they need not sum to 1.
    pub policy_target: HashMap<ChessMove, f32>,
    // Target for the value head. NOTE(review): the head's output is tanh-bounded, so this
    // presumably lies in [-1, 1]; confirm its perspective (side to move?) with callers.
    pub value_target: f32,
}
impl TrainingSample {
pub fn new(
board_state: BoardState,
policy_target: HashMap<ChessMove, f32>,
value_target: f32,
) -> Self {
TrainingSample {
board_state,
policy_target,
value_target,
}
}
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
}
}
/// Stateless batcher that collates [`TrainingSample`]s into a [`ChessBatch`].
#[derive(Clone, Default)]
pub struct ChessBatcher {}
/// A collated batch of training samples, as produced by `ChessBatcher`.
#[derive(Clone, Debug)]
pub struct ChessBatch<B: Backend> {
    // Encoded positions, [batch, 18, 8, 8].
    pub states: Tensor<B, 4>,
    // Normalised target move distributions, [batch, 4672].
    pub policy_targets: Tensor<B, 2>,
    // Target values, [batch, 1].
    pub value_targets: Tensor<B, 2>,
}
impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
    /// Collates training samples into a single [`ChessBatch`] on `device`.
    ///
    /// - `states`: each board is encoded with `encode_board_state_perspective`, reshaped to
    ///   `[1, 18, 8, 8]`, and concatenated along the batch axis, giving `[n, 18, 8, 8]`.
    /// - `policy_targets`: `[n, 4672]`. Each sample's move probabilities are scattered to
    ///   `encode_move(mv, side_to_move)`, and the row is rescaled to sum to 1. Rows with
    ///   no positive total mass are left unscaled.
    /// - `value_targets`: `[n, 1]`, the samples' `value_target`s in order.
    ///
    /// # Panics
    /// Panics if `encode_move` returns an index >= 4672. An empty `items` is not supported:
    /// the state tensors are joined with `Tensor::cat`, which presumably rejects an empty list
    /// (NOTE(review): confirm against the burn version in use).
    fn batch(&self, items: Vec<TrainingSample>, device: &B::Device) -> ChessBatch<B> {
        // Feature planes per encoded position.
        const INPUT_PLANES: usize = 18;
        // Flattened move space: 8 * 8 from-squares * 73 move planes.
        const POLICY_SIZE: usize = 8 * 8 * 73;

        let batch_len = items.len();

        let state_tensors = items
            .iter()
            .map(|item| {
                encode_board_state_perspective(&item.board_state, device)
                    .reshape([1, INPUT_PLANES, 8, 8])
            })
            .collect::<Vec<_>>();
        let states = Tensor::cat(state_tensors, 0); // [n, 18, 8, 8]

        // All policy rows live in one flat buffer whose declared shape [n, 4672] matches the
        // rank-2 tensor it becomes. This also avoids cloning each sample just to read it.
        let mut policy_flat = vec![0.0f32; batch_len * POLICY_SIZE];
        for (item, row) in items
            .iter()
            .zip(policy_flat.chunks_exact_mut(POLICY_SIZE))
        {
            let stm = item.board_state.board.side_to_move();
            for (&mv, &prob) in &item.policy_target {
                row[encode_move(mv, stm)] = prob;
            }
            // Normalise so each row is a probability distribution.
            let sum: f32 = row.iter().sum();
            if sum > 0.0 {
                row.iter_mut().for_each(|p| *p /= sum);
            }
        }
        let policy_targets = Tensor::<B, 2>::from_floats(
            TensorData::new(policy_flat, [batch_len, POLICY_SIZE]),
            device,
        ); // [n, 4672]

        let values: Vec<f32> = items.iter().map(|item| item.value_target).collect();
        let value_targets =
            Tensor::<B, 2>::from_floats(TensorData::new(values, [batch_len, 1]), device); // [n, 1]

        ChessBatch {
            states,
            policy_targets,
            value_targets,
        }
    }
}