// File listing metadata: 339 lines, 10 KiB, Rust
use crate::mcts::{BoardState, MctsResults};
|
|
use crate::net::encoding::{encode_board_state_perspective, encode_move};
|
|
use burn::data::dataloader::batcher::Batcher;
|
|
use burn::nn::conv::Conv2dConfig;
|
|
use burn::nn::loss::{MseLoss, Reduction};
|
|
use burn::nn::{BatchNorm, BatchNormConfig, LinearConfig, PaddingConfig2d};
|
|
use burn::tensor::activation::log_softmax;
|
|
use burn::tensor::backend::AutodiffBackend;
|
|
use burn::tensor::Transaction;
|
|
use burn::train::{InferenceStep, ItemLazy, TrainOutput, TrainStep};
|
|
use burn::{
|
|
nn::{conv::Conv2d, Linear, Relu},
|
|
prelude::*,
|
|
};
|
|
use burn_ndarray::NdArray;
|
|
use chess::ChessMove;
|
|
use std::collections::HashMap;
|
|
/*
Input planes (18 in total):
  1-6:   your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
  7-12:  opponent pieces
  13:    your kingside castling right
  14:    your queenside castling right
  15:    opponent kingside castling right
  16:    opponent queenside castling right
  17:    repetition count >= 2
  18:    en passant square (the gap square the pawn passed over)

Output:
  4672 moves represented through 73 planes over the 8x8 board (8 * 8 * 73 = 4672)
  1-56:  sliding moves
  57-64: knight moves
  65-73: underpromotions
*/
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct ResidualBlock<B: Backend> {
|
|
conv1: Conv2d<B>,
|
|
bn1: BatchNorm<B>,
|
|
conv2: Conv2d<B>,
|
|
bn2: BatchNorm<B>,
|
|
activation: Relu,
|
|
}
|
|
|
|
#[derive(Config, Debug)]
|
|
pub struct ResidualBlockConfig {
|
|
channels: usize,
|
|
}
|
|
|
|
impl<B: Backend> ResidualBlock<B> {
|
|
pub fn new(channels: usize, device: &B::Device) -> Self {
|
|
Self {
|
|
conv1: Conv2dConfig::new([channels, channels], [3, 3])
|
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
|
.init(device),
|
|
bn1: BatchNormConfig::new(channels).init(device),
|
|
conv2: Conv2dConfig::new([channels, channels], [3, 3])
|
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
|
.init(device),
|
|
bn2: BatchNormConfig::new(channels).init(device),
|
|
activation: Relu::new(),
|
|
}
|
|
}
|
|
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
|
let residual = x.clone();
|
|
|
|
let x = self.conv1.forward(x);
|
|
let x = self.bn1.forward(x);
|
|
let x = self.activation.forward(x);
|
|
|
|
let x = self.conv2.forward(x);
|
|
let x = self.bn2.forward(x);
|
|
|
|
self.activation.forward(x + residual)
|
|
}
|
|
}
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct ChessModel<B: Backend> {
|
|
// Initial convolution
|
|
conv: Conv2d<B>,
|
|
bn: BatchNorm<B>,
|
|
|
|
// Residual tower
|
|
residual_blocks: Vec<ResidualBlock<B>>,
|
|
|
|
// Policy head
|
|
policy_conv: Conv2d<B>,
|
|
policy_bn: BatchNorm<B>,
|
|
policy_fc: Linear<B>,
|
|
|
|
// Value head
|
|
value_conv: Conv2d<B>,
|
|
value_bn: BatchNorm<B>,
|
|
value_fc1: Linear<B>,
|
|
value_fc2: Linear<B>,
|
|
|
|
activation: Relu,
|
|
}
|
|
|
|
#[derive(Config, Debug)]
|
|
pub struct ChessModelConfig {
|
|
num_blocks: usize,
|
|
channels: usize,
|
|
}
|
|
|
|
impl ChessModelConfig {
|
|
pub fn init<B: Backend>(
|
|
num_blocks: usize,
|
|
channels: usize,
|
|
device: &B::Device,
|
|
) -> ChessModel<B> {
|
|
ChessModel {
|
|
conv: Conv2dConfig::new([18, channels], [3, 3]) // 18 plane input
|
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
|
.init(device),
|
|
bn: BatchNormConfig::new(channels).init(device),
|
|
|
|
residual_blocks: (0..num_blocks)
|
|
.map(|_| ResidualBlock::new(channels, device))
|
|
.collect(),
|
|
|
|
// Policy head
|
|
policy_conv: Conv2dConfig::new([channels, 2], [1, 1]).init(device),
|
|
policy_bn: BatchNormConfig::new(2).init(device),
|
|
policy_fc: LinearConfig::new(2 * 8 * 8, 8 * 8 * 73).init(device), // 4672 typical chess move space
|
|
|
|
// Value head
|
|
value_conv: Conv2dConfig::new([channels, 1], [1, 1]).init(device),
|
|
value_bn: BatchNormConfig::new(1).init(device),
|
|
value_fc1: LinearConfig::new(1 * 8 * 8, 256).init(device),
|
|
value_fc2: LinearConfig::new(256, 1).init(device),
|
|
|
|
activation: Relu::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<B: Backend> ChessModel<B> {
|
|
pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
|
let mut x = self.conv.forward(x);
|
|
x = self.bn.forward(x);
|
|
x = self.activation.forward(x);
|
|
|
|
for block in &self.residual_blocks {
|
|
x = block.forward(x);
|
|
}
|
|
|
|
let batch_size = x.dims()[0];
|
|
|
|
// -------- Policy Head --------
|
|
let mut p = self.policy_conv.forward(x.clone());
|
|
p = self.policy_bn.forward(p);
|
|
p = self.activation.forward(p);
|
|
let mut p = p.reshape([batch_size, 2 * 8 * 8]);
|
|
p = self.policy_fc.forward(p);
|
|
|
|
// -------- Value Head --------
|
|
let mut v = self.value_conv.forward(x);
|
|
v = self.value_bn.forward(v);
|
|
v = self.activation.forward(v);
|
|
let mut v = v.reshape([batch_size, 8 * 8]);
|
|
v = self.activation.forward(self.value_fc1.forward(v));
|
|
v = self.value_fc2.forward(v).tanh();
|
|
|
|
(p, v)
|
|
}
|
|
|
|
pub fn forward_chess(
|
|
&self,
|
|
boards: Tensor<B, 4>, // e.g. [batch, channels, 8, 8]
|
|
policy_targets: Tensor<B, 2>, // move distribution (make sure is normalized)
|
|
value_targets: Tensor<B, 2>, // scalar evaluation
|
|
) -> ChessOutput<B> {
|
|
let (policy_logits, value) = self.forward(boards);
|
|
|
|
let log_probs = log_softmax(policy_logits.clone(), 1);
|
|
let policy_loss = policy_targets
|
|
.clone()
|
|
.mul(log_probs)
|
|
.neg()
|
|
.sum_dim(1)
|
|
.mean();
|
|
|
|
let value_loss =
|
|
MseLoss::new().forward(value.clone(), value_targets.clone(), Reduction::Mean);
|
|
|
|
let total_loss = policy_loss + 0.5 * value_loss;
|
|
|
|
ChessOutput {
|
|
policy_logits,
|
|
value,
|
|
policy_targets,
|
|
value_targets,
|
|
loss: total_loss,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct ChessOutput<B: Backend> {
|
|
pub policy_logits: Tensor<B, 2>, // [sample, num_moves (4672)]
|
|
pub value: Tensor<B, 2>, // [sample, 1 value]
|
|
pub policy_targets: Tensor<B, 2>,
|
|
pub value_targets: Tensor<B, 2>, // kept 2d for consistency?
|
|
pub loss: Tensor<B, 1>, // [sample]
|
|
}
|
|
|
|
impl<B: Backend> ItemLazy for ChessOutput<B> {
|
|
type ItemSync = ChessOutput<NdArray>;
|
|
|
|
fn sync(self) -> Self::ItemSync {
|
|
let [policy_logits, value, policy_targets, value_targets, loss] = Transaction::default()
|
|
.register(self.policy_logits)
|
|
.register(self.value)
|
|
.register(self.policy_targets)
|
|
.register(self.value_targets)
|
|
.register(self.loss)
|
|
.execute()
|
|
.try_into()
|
|
.expect("Correct amount of tensor data");
|
|
|
|
let device = &Default::default();
|
|
|
|
ChessOutput {
|
|
policy_logits: Tensor::from_data(policy_logits, device),
|
|
value: Tensor::from_data(value, device),
|
|
policy_targets: Tensor::from_data(policy_targets, device),
|
|
value_targets: Tensor::from_data(value_targets, device),
|
|
loss: Tensor::from_data(loss, device),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<B: AutodiffBackend> TrainStep for ChessModel<B> {
|
|
type Input = ChessBatch<B>;
|
|
type Output = ChessOutput<B>;
|
|
|
|
fn step(&self, batch: ChessBatch<B>) -> TrainOutput<ChessOutput<B>> {
|
|
let item = self.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
|
|
|
|
TrainOutput::new(self, item.loss.backward(), item)
|
|
}
|
|
}
|
|
|
|
impl<B: Backend> InferenceStep for ChessModel<B> {
|
|
type Input = ChessBatch<B>;
|
|
type Output = ChessOutput<B>;
|
|
|
|
fn step(&self, batch: ChessBatch<B>) -> ChessOutput<B> {
|
|
self.forward_chess(batch.states, batch.policy_targets, batch.value_targets)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct TrainingSample {
|
|
pub board_state: BoardState,
|
|
pub policy_target: HashMap<ChessMove, f32>,
|
|
pub value_target: f32,
|
|
}
|
|
|
|
impl TrainingSample {
|
|
pub fn new(
|
|
board_state: BoardState,
|
|
policy_target: HashMap<ChessMove, f32>,
|
|
value_target: f32,
|
|
) -> Self {
|
|
TrainingSample {
|
|
board_state,
|
|
policy_target,
|
|
value_target,
|
|
}
|
|
}
|
|
|
|
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
|
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
|
}
|
|
}
|
|
|
|
/// Stateless batcher that turns `TrainingSample`s into a `ChessBatch`.
#[derive(Clone, Default)]
pub struct ChessBatcher {}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct ChessBatch<B: Backend> {
|
|
pub states: Tensor<B, 4>,
|
|
pub policy_targets: Tensor<B, 2>,
|
|
pub value_targets: Tensor<B, 2>,
|
|
}
|
|
|
|
impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
|
fn batch(&self, items: Vec<TrainingSample>, device: &B::Device) -> ChessBatch<B> {
|
|
let state_tensors = items
|
|
.iter()
|
|
.map(|item| {
|
|
encode_board_state_perspective(&item.board_state, device).reshape([1, 18, 8, 8])
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let policy_target_tensors = items
|
|
.iter()
|
|
.cloned()
|
|
.map(|item| {
|
|
let mut policy = vec![0.0f32; 4672];
|
|
let stm = item.board_state.board.side_to_move();
|
|
for (mv, prob) in item.policy_target {
|
|
policy[encode_move(mv, stm)] = prob;
|
|
}
|
|
|
|
// Normalize
|
|
let sum: f32 = policy.iter().sum();
|
|
if sum > 0.0 {
|
|
for p in &mut policy {
|
|
*p /= sum;
|
|
}
|
|
}
|
|
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze()
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let value_target_tensors = items
|
|
.iter()
|
|
.map(|item| {
|
|
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device)
|
|
.reshape([1, 1])
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let states = Tensor::cat(state_tensors, 0); // [B, 18, 8, 8]
|
|
let policy_targets = Tensor::cat(policy_target_tensors, 0); // [B, 4672]
|
|
let value_targets = Tensor::cat(value_target_tensors, 0); // [B, 1]
|
|
|
|
ChessBatch {
|
|
states,
|
|
policy_targets,
|
|
value_targets,
|
|
}
|
|
}
|
|
}
|