first working

This commit is contained in:
Drake Marino 2026-05-23 15:06:10 -05:00
commit 764252b0c4
32 changed files with 9465 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

10
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,10 @@
# Default ignored files
/shelf/
/workspace.xml
# Ignored default folder with query files
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/

15
.idea/chess_dragon.iml generated Normal file
View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/cli/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/engine/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/uci/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/web/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

6
.idea/copilot.data.migration.agent.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AgentMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

6
.idea/copilot.data.migration.ask.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AskMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Ask2AgentMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

6
.idea/copilot.data.migration.edit.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="EditMigrationStateService">
<option name="migrationStatus" value="COMPLETED" />
</component>
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/chess_dragon.iml" filepath="$PROJECT_DIR$/.idea/chess_dragon.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

7461
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

4
Cargo.toml Normal file
View File

@ -0,0 +1,4 @@
[workspace]
members = ["engine", "cli", "uci", "web"]
resolver = "2"

8
cli/Cargo.toml Normal file
View File

@ -0,0 +1,8 @@
[package]
name = "cli"
version = "0.1.0"
edition = "2024"
[dependencies]
engine = { path = "../engine" }

3
cli/src/commands.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod benchmark;
pub mod play;
pub mod train;

View File

@ -0,0 +1 @@

0
cli/src/commands/play.rs Normal file
View File

View File

@ -0,0 +1,4 @@
/// Entry point of a CLI subcommand (presumably `train`, judging by the
/// commented-out call below — TODO confirm against the file name).
///
/// Currently a stub: the only statement is commented out, so calling this
/// does nothing.
pub fn run() {
    // train(TrainingConfig {}, );
}

16
cli/src/main.rs Normal file
View File

@ -0,0 +1,16 @@
mod commands;
use std::env;
/// Dispatches to the subcommand named by the first command-line argument.
///
/// Recognised subcommands are `train`, `selfplay` and `benchmark`. Any other
/// value (or no argument at all) is reported on stderr together with a usage
/// line, and the process exits with status 2 so that scripts can detect the
/// failure instead of seeing a successful exit.
fn main() {
    // argv[0] is the program name; the first real argument selects the subcommand.
    let command = env::args().nth(1);
    match command.as_deref() {
        Some("train") => commands::train::run(),
        // NOTE(review): `commands.rs` declares `benchmark`, `play` and `train`
        // but no `selfplay` module — confirm which name is intended.
        Some("selfplay") => commands::selfplay::run(),
        Some("benchmark") => commands::benchmark::run(),
        other => {
            match other {
                Some(name) => eprintln!("unknown command: {name}"),
                None => eprintln!("missing command"),
            }
            eprintln!("usage: cli <train|selfplay|benchmark>");
            std::process::exit(2);
        }
    }
}

12
engine/Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "engine"
version = "0.1.0"
edition = "2024"
[dependencies]
chess = "3.2.0"
burn = { version = "0.21.0", features = ["std", "tui", "train", "cuda", "fusion", "wgpu", "ndarray"], default-features = false }
burn-ndarray = "0.21.0"
rand_distr = "0.6.0"
rand = "0.10.1"
ctrlc = "3.5.2"

70
engine/src/lib.rs Normal file
View File

@ -0,0 +1,70 @@
pub mod mcts;
mod net;
mod train;
pub mod training;
use chess::Game;
/// Default maximum search depth (presumably in plies — TODO confirm).
/// Not referenced by any code visible in this file.
pub const DEFAULT_MAX_DEPTH: u16 = 6;
/// Clock time each player is assumed to have left when none is supplied, in milliseconds.
pub const DEFAULT_PLAYER_TIME_REMAINING_MS: u64 = 120_000; // 2 minutes
/// Per-move clock increment assumed when none is supplied, in milliseconds.
pub const DEFAULT_PLAYER_INCREMENT_MS: u64 = 0;
/// A chess engine instance: the game being played plus the limits that
/// apply to the next search.
pub struct Engine {
    /// Current game state, as tracked by the `chess` crate.
    pub game: Game,
    /// Time and size limits for searching.
    pub search_settings: SearchSettings,
}
/// Limits that bound a single search.
///
/// The field names mirror the UCI `go` parameters (presumably on purpose —
/// TODO confirm once the UCI front end is wired up).
pub struct SearchSettings {
    /// White's remaining clock time; defaults to `DEFAULT_PLAYER_TIME_REMAINING_MS` (milliseconds).
    pub wtime: u64,
    /// Black's remaining clock time; defaults to `DEFAULT_PLAYER_TIME_REMAINING_MS` (milliseconds).
    pub btime: u64,
    /// White's increment per move; defaults to `DEFAULT_PLAYER_INCREMENT_MS` (milliseconds).
    pub winc: u64,
    /// Black's increment per move; defaults to `DEFAULT_PLAYER_INCREMENT_MS` (milliseconds).
    pub binc: u64,
    /// Optional fixed time budget for the move (presumably milliseconds — TODO confirm).
    pub movetime: Option<u64>,
    /// Optional cap on search depth.
    pub max_depth: Option<u16>,
    /// Optional cap on the number of searched nodes.
    pub max_nodes: Option<usize>,
}
impl Default for SearchSettings {
fn default() -> Self {
SearchSettings::new(None, None, None, None, None, None, None)
}
}
impl SearchSettings {
    /// Builds search settings, substituting the crate defaults for any clock
    /// or increment that is `None`.
    ///
    /// `movetime`, `max_depth` and `max_nodes` are stored as given; `None`
    /// means "no such limit".
    pub fn new(
        wtime: Option<u64>,
        btime: Option<u64>,
        winc: Option<u64>,
        binc: Option<u64>,
        movetime: Option<u64>,
        max_depth: Option<u16>,
        max_nodes: Option<usize>,
    ) -> Self {
        let clock_or_default = |clock: Option<u64>| clock.unwrap_or(DEFAULT_PLAYER_TIME_REMAINING_MS);
        let increment_or_default = |inc: Option<u64>| inc.unwrap_or(DEFAULT_PLAYER_INCREMENT_MS);
        Self {
            wtime: clock_or_default(wtime),
            btime: clock_or_default(btime),
            winc: increment_or_default(winc),
            binc: increment_or_default(binc),
            movetime,
            max_depth,
            max_nodes,
        }
    }
}
impl Default for Engine {
fn default() -> Self {
Engine::new(None)
}
}
impl Engine {
    /// Creates an engine at the start of a new game.
    ///
    /// When `search_settings` is `None`, the default settings are used.
    pub fn new(search_settings: Option<SearchSettings>) -> Self {
        let search_settings = search_settings.unwrap_or_default();
        let game = Game::new();
        Self {
            game,
            search_settings,
        }
    }

    /// Starts a search. Not implemented yet: the body is empty.
    pub fn go(&mut self) {}
}

104
engine/src/main.rs Normal file
View File

@ -0,0 +1,104 @@
#![recursion_limit = "256"]
use burn::backend::{Autodiff, Cuda};
use burn::optim::AdamConfig;
use engine::mcts::MctsConfig;
use engine::training::train::{train, TrainingConfig};
// fn main() {
// type MyBackend = Wgpu<f32, i32>;
//
// let device = Default::default();
// let model = ChessModelConfig::new(10, 512).init::<MyBackend>(&device);
//
// println!("{model}");
// }
/// Runs a training session with hard-coded settings on the CUDA backend.
///
/// Builds the MCTS, optimizer and training configuration and hands them to
/// `train`. Switching backend means changing the `MyBackend` alias and the
/// matching `device` line below.
fn main() {
    // Alternative backends, kept for quick switching:
    // type MyBackend = Wgpu<f32, i32>;
    // type MyBackend = NdArray<f32, i32>;
    type MyBackend = Cuda<f32, i32>;
    // Training needs gradients, so the backend is wrapped in `Autodiff`.
    type MyAutodiffBackend = Autodiff<MyBackend>;
    // let device = burn::backend::wgpu::WgpuDevice::default();
    // let device = burn::backend::ndarray::NdArrayDevice::default();
    let device = burn::backend::cuda::CudaDevice::default();
    // Arguments: num_simulations, c_puct, dirichlet_alpha, dirichlet_epsilon.
    let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25);
    let adam_config = AdamConfig::new();
    let training_config = TrainingConfig {
        max_time_s: None,
        num_iters: None,
        max_depth: 100, // unused
        model_name: String::from("Test1"),
        load_model: false,
        hidden_channels: 64,
        num_blocks: 4,
        batch_size: 128,          // positions sampled per batch (target value noted by the author: 2048)
        num_episodes: 10,         // games generated per iteration (target value noted by the author: 5000)
        buffer_max_size: 200_000, // max number of samples kept in the buffer
        mcts_config,
        optimizer: adam_config,
        lr: 2e-4,
    };
    train::<MyAutodiffBackend>(training_config, device);
    // A commented-out manual self-play loop used to follow here. It was removed
    // because it targeted an older, three-argument `Mcts::search` and no longer
    // matched the current API.
}

416
engine/src/mcts.rs Normal file
View File

@ -0,0 +1,416 @@
// use crate::networks::PolicyNetwork;
use crate::mcts::BoardStateStatus::{FiftyMove, Ongoing, Threefold};
use crate::net::encoding::{encode_board_state_perspective, encode_move};
use crate::net::model::ChessModel;
use burn::prelude::Backend;
use burn::Tensor;
use chess::BoardStatus::{Checkmate, Stalemate};
use chess::Color::White;
use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::time::Instant;
/// One node of the search tree. Nodes live in a flat `Vec<Node>` arena and
/// refer to each other by index.
pub struct Node {
    /// Prior weight of the move leading to this node, taken from the policy
    /// head (and mixed with Dirichlet noise for the root's children).
    pub prior: f32,
    /// Arena indices of the expanded children; empty until the node is expanded.
    pub children: Vec<usize>,
    /// Number of simulations whose path went through this node.
    pub visit_count: u32,
    /// Sum of the values backed up through this node (sign is set in `backpropagate`).
    pub value_sum: f32,
    /// Position reached at this node.
    pub board_state: BoardState,
    pub last_move: Option<ChessMove>, // move that produced this node from its parent
}
impl Node {
    /// Creates an unvisited, unexpanded node.
    pub fn new(prior: f32, board_state: BoardState, last_move: Option<ChessMove>) -> Node {
        Node {
            prior,
            children: vec![],
            visit_count: 0,
            value_sum: 0.0,
            board_state,
            last_move,
        }
    }

    /// Mean backed-up value of this node.
    ///
    /// Returns `0.0` for a node that has never been visited, instead of the
    /// NaN that `0.0 / 0.0` would produce.
    pub fn value(&self) -> f32 {
        if self.visit_count == 0 {
            0.0
        } else {
            self.value_sum / self.visit_count as f32
        }
    }

    /// Returns the arena index of the child with the highest UCB score.
    ///
    /// Scores are compared with `f32::total_cmp`, so a NaN score can no
    /// longer abort the search with a panic.
    ///
    /// # Panics
    /// Panics if this node has no children.
    pub fn select_child(&self, arena: &[Node], c_puct: &f32) -> usize {
        self.children
            .iter()
            .copied()
            .max_by(|&a, &b| {
                ucb_score(self, &arena[a], c_puct).total_cmp(&ucb_score(self, &arena[b], c_puct))
            })
            .expect("select_child on leaf")
    }
}
impl Clone for Node {
    /// Produces a *reset* copy: only the prior, position and last move are
    /// carried over, while `children`, `visit_count` and `value_sum` start
    /// from scratch because the copy goes through `Node::new`.
    ///
    /// NOTE(review): this differs from what `Clone` normally means — confirm
    /// that no caller expects the statistics or children to be preserved.
    fn clone(&self) -> Self {
        Self::new(self.prior, self.board_state.clone(), self.last_move.clone())
    }
}
/// Outcome of one `Mcts::search` call.
#[derive(Clone, Debug)]
pub struct MctsResults {
    /// The position the search was started from.
    pub board_state: BoardState,
    /// For each root move, its visit count divided by the number of simulations.
    pub move_dist: HashMap<ChessMove, f32>,
    /// Mean backed-up value of the root node.
    pub value: f32,
}
impl MctsResults {
pub fn new(
board_state: BoardState,
move_dist: HashMap<ChessMove, f32>,
value: f32,
) -> MctsResults {
MctsResults {
board_state,
move_dist,
value,
}
}
}
/// Tunable parameters of the tree search.
#[derive(Debug)]
pub struct MctsConfig {
    /// Number of select/expand/backpropagate iterations per search.
    pub num_simulations: usize,
    /// Multiplier of the prior-based exploration term in `ucb_score`.
    pub c_puct: f32,
    /// Shape parameter of the Gamma draws used to build root noise.
    pub dirichlet_alpha: f32,
    /// Weight of the noise when it is mixed into the root priors
    /// (`(1 - epsilon) * prior + epsilon * noise`).
    pub dirichlet_epsilon: f32,
}
impl MctsConfig {
pub fn new(
num_simulations: usize,
c_puct: f32,
dirichlet_alpha: f32,
dirichlet_epsilon: f32,
) -> MctsConfig {
MctsConfig {
num_simulations,
c_puct,
dirichlet_alpha,
dirichlet_epsilon,
}
}
}
impl Default for MctsConfig {
fn default() -> MctsConfig {
MctsConfig::new(400, 1.0, 0.05, 0.25)
}
}
/// Monte Carlo tree search driver, generic over the tensor backend `B` of
/// the model it evaluates positions with.
pub struct Mcts<B: Backend> {
    /// Search parameters.
    pub config: MctsConfig,
    /// Ties the struct to backend `B` without storing a value of that type.
    pub _marker: PhantomData<B>, // if B not otherwise stored
}
impl<B: Backend> Mcts<B> {
pub fn search(
&mut self,
board_state: &BoardState,
model: &ChessModel<B>,
device: &B::Device,
) -> MctsResults {
let mut nodes = Vec::<Node>::new();
let root = 0;
nodes.push(Node::new(0.0, board_state.clone(), None));
self.expand(root, &mut nodes, model, device);
// 👇 APPLY DIRICHLET NOISE HERE
self.add_dirichlet_noise(root, &mut nodes);
for i in 0..self.config.num_simulations {
// if i % 10 == 0 {
// println!("mcts sim #{}", i);
// }
let mut path = vec![root];
let mut current = root;
while !nodes[current].children.is_empty() {
current = nodes[current].select_child(&nodes, &self.config.c_puct);
path.push(current);
}
let value: f32 = self.expand(current, &mut nodes, model, device);
let color = nodes[current].board_state.board.side_to_move();
self.backpropagate(&mut nodes, &path, value, color);
}
let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)>
for idx in nodes[root].children.iter() {
move_dist.insert(
nodes[*idx].last_move.expect("move didnt exist"),
nodes[*idx].visit_count as f32 / self.config.num_simulations as f32,
);
}
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
}
fn expand(
&mut self,
node_idx: usize,
arena: &mut Vec<Node>,
model: &ChessModel<B>,
device: &B::Device,
) -> f32 {
let state: Tensor<B, 4> =
encode_board_state_perspective(&arena[node_idx].board_state, device)
.reshape([1, 18, 8, 8]);
let start = Instant::now();
let (policy_head, value_head) = model.forward(state);
println!("time: {:?}", start.elapsed());
let legal_moves: Vec<ChessMove> =
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();
let policy = policy_head.into_data().to_vec::<f32>().unwrap();
for mv in legal_moves {
let stm = arena[node_idx].board_state.board.side_to_move();
let idx = encode_move(mv, stm);
let prior = policy[idx];
let mut new_board = arena[node_idx].board_state.clone();
new_board.apply_move(mv);
let child_idx = arena.len();
arena.push(Node::new(prior, new_board, Some(mv)));
arena[node_idx].children.push(child_idx);
}
value_head.into_data().to_vec().unwrap()[0]
}
fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) {
for &idx in path {
let node = &mut nodes[idx];
if node.board_state.board.side_to_move() == color {
node.value_sum += value;
} else {
node.value_sum -= value;
}
node.visit_count += 1;
}
}
fn add_dirichlet_noise(&mut self, node_id: usize, nodes: &mut Vec<Node>) {
let node_children = &nodes[node_id].children.clone();
let n = node_children.len();
if n == 0 {
return;
}
let alpha = self.config.dirichlet_alpha;
let epsilon = self.config.dirichlet_epsilon;
let noise = dirichlet_sample(n, alpha);
for (node_idx, noise_val) in node_children.iter().zip(noise) {
let prev_prior = nodes[*node_idx].prior;
nodes[*node_idx].prior = (1.0 - epsilon) * prev_prior + epsilon * noise_val;
}
}
}
/// Draws one sample from a symmetric Dirichlet distribution with `size`
/// components and concentration `alpha`, by normalising independent
/// Gamma(`alpha`, 1) draws.
///
/// The draws are kept and normalised in `f64`: with a small `alpha` they are
/// often tiny enough to underflow to zero in `f32`, which made the sum zero
/// and the result NaN. If the sum is still not a positive finite number, a
/// uniform vector is returned. `size == 0` yields an empty vector.
///
/// # Panics
/// Panics if `alpha` is not a valid Gamma shape (for example zero or negative).
fn dirichlet_sample(size: usize, alpha: f32) -> Vec<f32> {
    use rand_distr::{Distribution, Gamma};
    if size == 0 {
        return Vec::new();
    }
    let gamma = Gamma::new(alpha as f64, 1.0).unwrap();
    // Fetch the thread-local generator once instead of once per draw.
    let mut rng = rand::rng();
    let draws: Vec<f64> = (0..size).map(|_| gamma.sample(&mut rng)).collect();
    let sum: f64 = draws.iter().sum();
    if !(sum.is_finite() && sum > 0.0) {
        return vec![1.0 / size as f32; size];
    }
    draws.into_iter().map(|draw| (draw / sum) as f32).collect()
}
/// Selection score of `child` as seen from `parent`.
///
/// The score is the sum of an exploitation term (the negated mean value of
/// the child, since that value belongs to the opposing side; `0.0` while the
/// child is unvisited) and an exploration term that grows with the child's
/// prior and the parent's visit count and shrinks with the child's visits.
fn ucb_score(parent: &Node, child: &Node, c_puct: &f32) -> f32 {
    let exploration = c_puct * child.prior * (parent.visit_count as f32).sqrt()
        / (1.0 + child.visit_count as f32);
    let exploitation: f32 = if child.visit_count > 0 {
        -child.value() // value from opposing side
    } else {
        0.0
    };
    exploitation + exploration
}
/// Picks the most-visited child of `root`.
///
/// Returns the child's position and its move in the text form produced by
/// `ChessMove::to_string`. Returns `None` when the root has no children, when
/// no child has been visited at least once, or when the chosen child has no
/// recorded move. On equal visit counts the earliest child wins. The visit
/// count of the winner is printed to stdout.
pub fn find_best_move(nodes: &[Node], root: usize) -> Option<(BoardState, String)> {
    // Strict `>` against the current leader (0 while there is none) keeps the
    // first of several equally visited children and skips unvisited ones.
    let winner = nodes[root]
        .children
        .iter()
        .copied()
        .fold(None, |leader: Option<usize>, candidate| {
            let leader_visits = leader.map_or(0u32, |idx| nodes[idx].visit_count);
            if nodes[candidate].visit_count > leader_visits {
                Some(candidate)
            } else {
                leader
            }
        })?;
    let best = &nodes[winner];
    println!("best_child visits: {}", best.visit_count);
    let uci_move = best.last_move?.to_string();
    Some((best.board_state.clone(), uci_move))
}
/// Prints the subtree rooted at `root` to stdout in depth-first pre-order,
/// one node per line, indented by one space per level starting at `depth`.
pub fn print_tree_arena(nodes: &[Node], root: usize, depth: usize) {
    // Explicit stack instead of recursion. Children are pushed in reverse so
    // they are popped, and therefore printed, in their original order.
    let mut pending = vec![(root, depth)];
    while let Some((idx, level)) = pending.pop() {
        let node = &nodes[idx];
        let indent = " ".repeat(level);
        let side = node.board_state.board.side_to_move();
        println!(
            "{}- Value: {:.2}, Visits: {}, Prior: {:.2}, Side: {:?}",
            indent, node.value_sum, node.visit_count, node.prior, side
        );
        pending.extend(node.children.iter().rev().map(|&child| (child, level + 1)));
    }
}
fn material_of(piece: Piece) -> f32 {
match piece {
Queen => 9.0,
Rook => 5.0,
Bishop => 3.0,
Knight => 3.0,
Pawn => 1.0,
_ => 0.0,
}
}
/// Scale applied to raw material so that the evaluation stays small.
const MATERIAL_VALUE_DIVISOR: f32 = 40.0;

/// Material-only evaluation of `board` from `perspective`'s point of view.
///
/// Every piece type of every colour contributes `count * value / 40`,
/// added for `perspective`'s own pieces and subtracted for the opponent's.
pub fn heuristic_eval(board: &Board, perspective: Color) -> f32 {
    let mut score = 0.0;
    for color in ALL_COLORS {
        // Own material counts for us, enemy material against us.
        let sign: f32 = if color == perspective { 1.0 } else { -1.0 };
        for piece in ALL_PIECES {
            let occupied = board.color_combined(color).0 & board.pieces(piece).0;
            let count = occupied.count_ones() as f32;
            let contribution = (count * material_of(piece)) / MATERIAL_VALUE_DIVISOR;
            score += sign * contribution;
        }
    }
    score
}
/// Game status as tracked by [`BoardState`], including the draw conditions
/// that depend on move history.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoardStateStatus {
    /// No terminal condition has been detected.
    Ongoing,
    /// The side to move has no legal move and is not in check.
    Stalemate,
    /// Black has been checkmated.
    WhiteWinner,
    /// White has been checkmated.
    BlackWinner,
    /// The current position has been recorded at least three times.
    Threefold,
    /// The halfmove clock has reached 100 plies.
    FiftyMove,
}
/// A board position together with the history-dependent data that the
/// `chess` crate's `Board` does not track.
#[derive(Clone, Debug)]
pub struct BoardState {
    /// The position itself.
    pub board: Board,
    /// Plies since the last pawn move or capture (reset in `apply_move`).
    pub halfmove_clock: u8,
    /// Number of times each position hash has been reached.
    pub repetition_table: HashMap<u64, u8>,
    /// Status derived from the board and the two counters above.
    pub status: BoardStateStatus,
}
impl BoardState {
    /// Creates a state for `board` with the given halfmove clock and
    /// repetition history; the status starts as `Ongoing`.
    ///
    /// If `repetition_table` has no entry for `board` yet, the position is
    /// recorded as seen once, so that the starting position takes part in
    /// threefold-repetition counting like every later position.
    pub fn new(board: Board, halfmove_clock: u8, repetition_table: HashMap<u64, u8>) -> BoardState {
        let mut repetition_table = repetition_table;
        repetition_table.entry(board.get_hash()).or_insert(1);
        BoardState {
            board,
            halfmove_clock,
            repetition_table,
            status: Ongoing,
        }
    }

    /// Plays `mv` and updates the halfmove clock, repetition table and status.
    ///
    /// The move is not validated. Status precedence after the move:
    /// checkmate, then threefold repetition, then the fifty-move rule, then
    /// stalemate. A mating move therefore always produces a winner, even if
    /// it also completes the hundredth ply without a capture or pawn move.
    /// If none of these applies, the previous status is kept.
    pub fn apply_move(&mut self, mv: ChessMove) {
        // Pawn moves and captures reset the clock. Both counters saturate
        // rather than overflow their `u8`.
        let resets_clock = self.board.piece_on(mv.get_source()) == Some(Pawn)
            || self.board.piece_on(mv.get_dest()).is_some();
        self.halfmove_clock = if resets_clock {
            0
        } else {
            self.halfmove_clock.saturating_add(1)
        };

        self.board = self.board.make_move_new(mv);

        let seen = self.repetition_table.entry(self.board.get_hash()).or_insert(0);
        *seen = seen.saturating_add(1);
        let repetitions = *seen;

        let board_status = self.board.status();
        self.status = if board_status == Checkmate {
            if self.board.side_to_move() == White {
                // white's move after black played the mating move
                BoardStateStatus::BlackWinner
            } else {
                BoardStateStatus::WhiteWinner
            }
        } else if repetitions >= 3 {
            Threefold
        } else if self.halfmove_clock >= 100 {
            FiftyMove
        } else if board_status == Stalemate {
            BoardStateStatus::Stalemate
        } else {
            self.status
        };
    }
}
impl Default for BoardState {
    /// The standard starting position with a zeroed clock and an empty
    /// repetition table passed to `BoardState::new`.
    fn default() -> BoardState {
        let start = Board::default();
        let history: HashMap<u64, u8> = HashMap::new();
        Self::new(start, 0, history)
    }
}

2
engine/src/net.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod encoding;
pub mod model;

299
engine/src/net/encoding.rs Normal file
View File

@ -0,0 +1,299 @@
use crate::mcts::BoardState;
use burn::prelude::*;
use chess::{ChessMove, Color, File, Piece, Rank, Square, ALL_SQUARES};
/*
Input planes:
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
7-12: opponent pieces
13: your kingside
14: your queenside
15: opponent kingside
16: opponent queenside
17: rep >= 2
18: En passant square (gap pass over)
Output:
4672 moves represented through planes
1-56: sliding moves
57-64: Knight moves
65-73: underpromotions
*/
pub fn encode_board_state_perspective<B: Backend>(
state: &BoardState,
device: &B::Device,
) -> Tensor<B, 3> {
let mut planes = vec![0.0f32; 18 * 64];
let board = &state.board;
let us = board.side_to_move();
let them = !us;
let idx = |plane: usize, rank: usize, file: usize| -> usize { plane * 64 + rank * 8 + file };
let flip = us == Color::Black;
for &square in ALL_SQUARES.iter() {
if let Some(piece) = board.piece_on(square) {
let color = board.color_on(square).unwrap();
let piece_index = match piece {
Piece::Pawn => 0,
Piece::Knight => 1,
Piece::Bishop => 2,
Piece::Rook => 3,
Piece::Queen => 4,
Piece::King => 5,
};
// Determine plane based on perspective
let plane = if color == us {
piece_index
} else {
piece_index + 6
};
let mut rank = square.get_rank().to_index();
let mut file = square.get_file().to_index();
// Flip board if black to move
if flip {
rank = 7 - rank;
file = 7 - file;
}
planes[idx(plane, rank, file)] = 1.0;
}
}
// -------------------------
// Castling Rights (planes 1215)
// From perspective
// -------------------------
let (us_castle, them_castle) = (board.castle_rights(us), board.castle_rights(them));
if us_castle.has_kingside() {
fill_plane(&mut planes, 12);
}
if us_castle.has_queenside() {
fill_plane(&mut planes, 13);
}
if them_castle.has_kingside() {
fill_plane(&mut planes, 14);
}
if them_castle.has_queenside() {
fill_plane(&mut planes, 15);
}
// -------------------------
// Repetition plane (16)
// -------------------------
let current_hash = board.get_hash();
if let Some(count) = state.repetition_table.get(&current_hash) {
if *count >= 2 {
fill_plane(&mut planes, 16);
}
}
let en_passant = board.en_passant();
if let Some(ep) = en_passant {
let mut rank = ep.get_rank().to_index();
let mut file = ep.get_file().to_index();
// Flip board if black to move
if flip {
rank = 7 - rank;
file = 7 - file;
}
planes[idx(17, rank, file)] = 1.0;
}
Tensor::<B, 3>::from_floats(TensorData::new(planes, [18, 8, 8]), device)
}
/// Sets all 64 cells of plane number `plane` in the flat `buffer` to `1.0`.
///
/// # Panics
/// Panics if the plane extends past the end of `buffer`.
fn fill_plane(buffer: &mut [f32], plane: usize) {
    let first = plane * 64;
    (first..first + 64).for_each(|cell| buffer[cell] = 1.0);
}
/// Maps a move to its index in the flat policy vector:
/// `move_type_plane * 64 + from_rank * 8 + from_file`.
///
/// Coordinates are taken from the mover's point of view: when
/// `side_to_move` is Black, both squares are rotated by 180° first.
pub fn encode_move(mv: ChessMove, side_to_move: Color) -> usize {
    let rotate = side_to_move == Color::Black;
    // Splits a 0..64 square index into (rank, file) in perspective space.
    let perspective = |square_index: usize| -> (usize, usize) {
        let (rank, file) = (square_index / 8, square_index % 8);
        if rotate {
            (7 - rank, 7 - file)
        } else {
            (rank, file)
        }
    };
    let (from_rank, from_file) = perspective(mv.get_source().to_index());
    let (to_rank, to_file) = perspective(mv.get_dest().to_index());

    let delta_rank = to_rank as i32 - from_rank as i32;
    let delta_file = to_file as i32 - from_file as i32;
    let plane = encode_move_type(delta_rank, delta_file, mv.get_promotion());
    plane * 64 + (from_rank * 8 + from_file)
}
/// Maps a perspective-space displacement (`dr` ranks, `df` files) and an
/// optional promotion piece to one of the 73 move-type planes.
///
/// * 0–55: sliding moves, `direction * 7 + (distance - 1)`, directions
///   clockwise starting from "north". Queen promotions land here too.
/// * 56–63: knight moves, in the order of `KNIGHT_DELTAS`.
/// * 64–72: underpromotions, `64 + file_direction * 3 + piece`
///   (file direction: straight, towards lower files, towards higher files;
///   piece: rook, bishop, knight).
///
/// # Panics
/// Panics with "Invalid move delta" for a zero displacement, and through
/// `unreachable!` for a promotion to anything other than queen, rook, bishop
/// or knight.
fn encode_move_type(dr: i32, df: i32, promotion: Option<Piece>) -> usize {
    const KNIGHT_DELTAS: [(i32, i32); 8] = [
        (2, 1),
        (1, 2),
        (-1, 2),
        (-2, 1),
        (-2, -1),
        (-1, -2),
        (1, -2),
        (2, -1),
    ];

    // Knight moves.
    if let Some(slot) = KNIGHT_DELTAS.iter().position(|&(r, f)| r == dr && f == df) {
        return 56 + slot;
    }

    // Underpromotions; queen promotions fall through to the sliding planes.
    match promotion {
        None | Some(Piece::Queen) => {}
        Some(promo) => {
            let dir: usize = match df.signum() {
                0 => 0,
                -1 => 1,
                _ => 2,
            };
            let piece_index: usize = match promo {
                Piece::Rook => 0,
                Piece::Bishop => 1,
                Piece::Knight => 2,
                _ => unreachable!(),
            };
            return 64 + dir * 3 + piece_index;
        }
    }

    // Sliding moves.
    let direction_index = match (dr.signum(), df.signum()) {
        (1, 0) => 0, // N
        (1, 1) => 1, // NE
        (0, 1) => 2, // E
        (-1, 1) => 3, // SE
        (-1, 0) => 4, // S
        (-1, -1) => 5, // SW
        (0, -1) => 6, // W
        (1, -1) => 7, // NW
        _ => panic!("Invalid move delta"),
    };
    let distance = dr.abs().max(df.abs()) as usize - 1;
    direction_index * 7 + distance
}
/// Inverse of `encode_move`: rebuilds a move from its policy index.
///
/// The index is split into a source square and a move-type plane, both in
/// the mover's perspective, and converted back to absolute board
/// coordinates (for Black: rotate the source square and negate the
/// displacement).
///
/// The promotion piece is only set for underpromotion planes; a sliding-plane
/// index always decodes with no promotion. The target square is not checked
/// against the board edges.
pub fn decode_move(index: usize, side_to_move: Color) -> ChessMove {
    let (plane, from_index) = (index / 64, index % 64);
    let (rank_in_view, file_in_view) = (from_index / 8, from_index % 8);
    let (dr, df, promotion) = decode_move_type(plane);

    // Convert from perspective coordinates back to absolute board coordinates.
    let (from_rank, from_file, dr, df) = match side_to_move {
        Color::Black => (7 - rank_in_view, 7 - file_in_view, -dr, -df),
        Color::White => (rank_in_view, file_in_view, dr, df),
    };

    let to_rank = (from_rank as i32 + dr) as usize;
    let to_file = (from_file as i32 + df) as usize;
    let source = Square::make_square(Rank::from_index(from_rank), File::from_index(from_file));
    let target = Square::make_square(Rank::from_index(to_rank), File::from_index(to_file));
    ChessMove::new(source, target, promotion)
}
/// Inverse of `encode_move_type`: turns a move-type plane into a
/// perspective-space displacement `(dr, df)` and an optional underpromotion
/// piece.
///
/// # Panics
/// Panics through `unreachable!` for planes of 73 and above.
fn decode_move_type(plane: usize) -> (i32, i32, Option<Piece>) {
    const KNIGHT_DELTAS: [(i32, i32); 8] = [
        (2, 1),
        (1, 2),
        (-1, 2),
        (-2, 1),
        (-2, -1),
        (-1, -2),
        (1, -2),
        (2, -1),
    ];
    // Unit steps of the eight sliding directions, clockwise from north.
    const SLIDE_STEPS: [(i32, i32); 8] = [
        (1, 0),
        (1, 1),
        (0, 1),
        (-1, 1),
        (-1, 0),
        (-1, -1),
        (0, -1),
        (1, -1),
    ];

    match plane {
        // 0–55: sliding moves
        0..=55 => {
            let (step_rank, step_file) = SLIDE_STEPS[plane / 7];
            let distance = ((plane % 7) + 1) as i32;
            (step_rank * distance, step_file * distance, None)
        }
        // 56–63: knight moves
        56..=63 => {
            let (dr, df) = KNIGHT_DELTAS[plane - 56];
            (dr, df, None)
        }
        // 64–72: underpromotions
        _ => {
            let promo_plane = plane - 64;
            let df = match promo_plane / 3 {
                0 => 0,
                1 => -1,
                2 => 1,
                _ => unreachable!(),
            };
            let promotion = match promo_plane % 3 {
                0 => Piece::Rook,
                1 => Piece::Bishop,
                2 => Piece::Knight,
                _ => unreachable!(),
            };
            // Always one rank forward in perspective space; `decode_move`
            // negates it for Black.
            (1, df, Some(promotion))
        }
    }
}

338
engine/src/net/model.rs Normal file
View File

@ -0,0 +1,338 @@
use crate::mcts::{BoardState, MctsResults};
use crate::net::encoding::{encode_board_state_perspective, encode_move};
use burn::data::dataloader::batcher::Batcher;
use burn::nn::conv::Conv2dConfig;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{BatchNorm, BatchNormConfig, LinearConfig, PaddingConfig2d};
use burn::tensor::activation::log_softmax;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Transaction;
use burn::train::{InferenceStep, ItemLazy, TrainOutput, TrainStep};
use burn::{
nn::{conv::Conv2d, Linear, Relu},
prelude::*,
};
use burn_ndarray::NdArray;
use chess::ChessMove;
use std::collections::HashMap;
/*
Input planes:
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
7-12: opponent pieces
13: your kingside
14: your queenside
15: opponent kingside
16: opponent queenside
17: rep >= 2
18: En passant square (gap pass over)
Output:
4672 moves represented through planes
1-56: sliding moves
57-64: Knight moves
65-73: underpromotions
*/
/// Residual block of the network tower: two 3x3 convolutions, each followed
/// by batch normalisation, with a skip connection around both.
#[derive(Module, Debug)]
pub struct ResidualBlock<B: Backend> {
    /// First 3x3 convolution.
    conv1: Conv2d<B>,
    /// Batch norm applied after `conv1`.
    bn1: BatchNorm<B>,
    /// Second 3x3 convolution.
    conv2: Conv2d<B>,
    /// Batch norm applied after `conv2`.
    bn2: BatchNorm<B>,
    /// ReLU used after `bn1` and after the skip addition.
    activation: Relu,
}
/// Configuration for a [`ResidualBlock`].
///
/// NOTE(review): nothing visible in this file reads this config;
/// `ResidualBlock::new` takes the channel count directly.
#[derive(Config, Debug)]
pub struct ResidualBlockConfig {
    /// Number of input and output channels of the block.
    channels: usize,
}
impl<B: Backend> ResidualBlock<B> {
    /// Builds a block whose convolutions keep `channels` channels and, thanks
    /// to a padding of 1 on every side, the spatial size of the input.
    pub fn new(channels: usize, device: &B::Device) -> Self {
        let conv3x3 = || {
            Conv2dConfig::new([channels, channels], [3, 3])
                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
                .init(device)
        };
        let batch_norm = || BatchNormConfig::new(channels).init(device);
        // Field expressions run top to bottom, so layers are initialised in
        // the order conv1, bn1, conv2, bn2.
        Self {
            conv1: conv3x3(),
            bn1: batch_norm(),
            conv2: conv3x3(),
            bn2: batch_norm(),
            activation: Relu::new(),
        }
    }

    /// Applies conv → norm → ReLU → conv → norm, adds the input back and
    /// applies a final ReLU.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let skip = x.clone();
        let hidden = self.activation.forward(self.bn1.forward(self.conv1.forward(x)));
        let hidden = self.bn2.forward(self.conv2.forward(hidden));
        self.activation.forward(hidden + skip)
    }
}
/// Policy/value network: an input convolution, a tower of residual blocks,
/// and two heads that share the tower's output.
#[derive(Module, Debug)]
pub struct ChessModel<B: Backend> {
    // Initial convolution (18 input planes to the tower's channel count)
    conv: Conv2d<B>,
    bn: BatchNorm<B>,
    // Residual tower
    residual_blocks: Vec<ResidualBlock<B>>,
    // Policy head: 1x1 convolution to 2 channels, then a linear layer to 4672 move logits
    policy_conv: Conv2d<B>,
    policy_bn: BatchNorm<B>,
    policy_fc: Linear<B>,
    // Value head: 1x1 convolution to 1 channel, then 64 -> 256 -> 1 linear layers
    value_conv: Conv2d<B>,
    value_bn: BatchNorm<B>,
    value_fc1: Linear<B>,
    value_fc2: Linear<B>,
    // ReLU shared by every stage
    activation: Relu,
}
/// Configuration for a [`ChessModel`].
///
/// NOTE(review): the `init` function on this type receives the sizes as
/// arguments and does not read these fields.
#[derive(Config, Debug)]
pub struct ChessModelConfig {
    /// Number of residual blocks in the tower.
    num_blocks: usize,
    /// Channel count used throughout the tower.
    channels: usize,
}
impl ChessModelConfig {
    /// Creates a freshly initialised model on `device`.
    ///
    /// This is an associated function: `num_blocks` and `channels` are passed
    /// in directly rather than read from a config value.
    pub fn init<B: Backend>(
        num_blocks: usize,
        channels: usize,
        device: &B::Device,
    ) -> ChessModel<B> {
        // Layers are created in the same order as the struct fields.
        // Stem: 18 input planes, 3x3 kernel, padding 1 on every side.
        let conv = Conv2dConfig::new([18, channels], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
            .init(device);
        let bn = BatchNormConfig::new(channels).init(device);
        let residual_blocks = (0..num_blocks)
            .map(|_| ResidualBlock::new(channels, device))
            .collect();

        // Policy head: 8 * 8 * 73 = 4672 move logits.
        let policy_conv = Conv2dConfig::new([channels, 2], [1, 1]).init(device);
        let policy_bn = BatchNormConfig::new(2).init(device);
        let policy_fc = LinearConfig::new(2 * 8 * 8, 8 * 8 * 73).init(device);

        // Value head: a single scalar per position.
        let value_conv = Conv2dConfig::new([channels, 1], [1, 1]).init(device);
        let value_bn = BatchNormConfig::new(1).init(device);
        let value_fc1 = LinearConfig::new(8 * 8, 256).init(device);
        let value_fc2 = LinearConfig::new(256, 1).init(device);

        ChessModel {
            conv,
            bn,
            residual_blocks,
            policy_conv,
            policy_bn,
            policy_fc,
            value_conv,
            value_bn,
            value_fc1,
            value_fc2,
            activation: Relu::new(),
        }
    }
}
impl<B: Backend> ChessModel<B> {
    /// Runs the network on a batch of encoded positions.
    ///
    /// Returns `(policy, value)`: `policy` holds one row of 4672 raw logits
    /// per position (no softmax applied), `value` one tanh-squashed scalar
    /// per position.
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        // Stem followed by the residual tower.
        let stem = self.activation.forward(self.bn.forward(self.conv.forward(x)));
        let trunk = self
            .residual_blocks
            .iter()
            .fold(stem, |features, block| block.forward(features));
        let batch_size = trunk.dims()[0];

        // -------- Policy Head --------
        let policy = self.policy_conv.forward(trunk.clone());
        let policy = self.activation.forward(self.policy_bn.forward(policy));
        let policy = policy.reshape([batch_size, 2 * 8 * 8]);
        let policy = self.policy_fc.forward(policy);

        // -------- Value Head --------
        let value = self.value_conv.forward(trunk);
        let value = self.activation.forward(self.value_bn.forward(value));
        let value = value.reshape([batch_size, 8 * 8]);
        let value = self.activation.forward(self.value_fc1.forward(value));
        let value = self.value_fc2.forward(value).tanh();

        (policy, value)
    }

    /// Forward pass plus loss computation for one batch.
    ///
    /// The loss is the cross-entropy between `policy_targets` and the
    /// log-softmax of the policy logits (averaged over the batch) plus half
    /// the mean squared error between the predicted and target values.
    /// `policy_targets` is expected to hold a normalised distribution per row.
    pub fn forward_chess(
        &self,
        boards: Tensor<B, 4>,         // e.g. [batch, channels, 8, 8]
        policy_targets: Tensor<B, 2>, // move distribution (make sure is normalized)
        value_targets: Tensor<B, 2>,  // scalar evaluation
    ) -> ChessOutput<B> {
        let (policy_logits, value) = self.forward(boards);

        let log_probs = log_softmax(policy_logits.clone(), 1);
        let cross_entropy = policy_targets
            .clone()
            .mul(log_probs)
            .neg()
            .sum_dim(1)
            .mean();
        let value_mse =
            MseLoss::new().forward(value.clone(), value_targets.clone(), Reduction::Mean);
        let loss = cross_entropy + 0.5 * value_mse;

        ChessOutput {
            policy_logits,
            value,
            policy_targets,
            value_targets,
            loss,
        }
    }
}
/// Everything produced by one training or inference step: the predictions,
/// the targets they were compared with, and the resulting loss.
pub struct ChessOutput<B: Backend> {
    pub policy_logits: Tensor<B, 2>, // [sample, num_moves (4672)]
    pub value: Tensor<B, 2>,         // [sample, 1 value]
    pub policy_targets: Tensor<B, 2>, // same layout as `policy_logits`
    pub value_targets: Tensor<B, 2>, // same layout as `value`
    pub loss: Tensor<B, 1>,          // mean-reduced total loss (policy + 0.5 * value)
}
impl<B: Backend> ItemLazy for ChessOutput<B> {
    /// The synchronised form lives on the `NdArray` backend.
    type ItemSync = ChessOutput<NdArray>;

    /// Reads all five tensors back in a single transaction and rebuilds the
    /// output on the default `NdArray` device.
    fn sync(self) -> Self::ItemSync {
        // The destructuring order below must match the `register` order.
        let [policy_logits, value, policy_targets, value_targets, loss] = Transaction::default()
            .register(self.policy_logits)
            .register(self.value)
            .register(self.policy_targets)
            .register(self.value_targets)
            .register(self.loss)
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");
        let device = &Default::default();
        ChessOutput {
            policy_logits: Tensor::from_data(policy_logits, device),
            value: Tensor::from_data(value, device),
            policy_targets: Tensor::from_data(policy_targets, device),
            value_targets: Tensor::from_data(value_targets, device),
            loss: Tensor::from_data(loss, device),
        }
    }
}
impl<B: AutodiffBackend> TrainStep for ChessModel<B> {
type Input = ChessBatch<B>;
type Output = ChessOutput<B>;
fn step(&self, batch: ChessBatch<B>) -> TrainOutput<ChessOutput<B>> {
let item = self.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> InferenceStep for ChessModel<B> {
type Input = ChessBatch<B>;
type Output = ChessOutput<B>;
fn step(&self, batch: ChessBatch<B>) -> ChessOutput<B> {
self.forward_chess(batch.states, batch.policy_targets, batch.value_targets)
}
}
/// One training example: a position, the move distribution to learn for it,
/// and the value to learn for it.
#[derive(Clone)]
pub struct TrainingSample {
    /// Position the sample was recorded at.
    pub board_state: BoardState,
    /// Target weight per move; normalised by the batcher before use.
    pub policy_target: HashMap<ChessMove, f32>,
    /// Target for the value head.
    pub value_target: f32,
}
impl TrainingSample {
pub fn new(
board_state: BoardState,
policy_target: HashMap<ChessMove, f32>,
value_target: f32,
) -> Self {
TrainingSample {
board_state,
policy_target,
value_target,
}
}
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
}
}
/// Stateless batcher that turns [`TrainingSample`]s into a [`ChessBatch`].
#[derive(Clone, Default)]
pub struct ChessBatcher {}
/// A batch of training data, one row per sample.
#[derive(Clone, Debug)]
pub struct ChessBatch<B: Backend> {
    /// Encoded positions, `[batch, 18, 8, 8]`.
    pub states: Tensor<B, 4>,
    /// Normalised move distributions, `[batch, 4672]`.
    pub policy_targets: Tensor<B, 2>,
    /// Value targets, `[batch, 1]`.
    pub value_targets: Tensor<B, 2>,
}
impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
    /// Collates `items` into one batch of tensors on `device`.
    ///
    /// With `N == items.len()` the result holds:
    /// - `states`: `[N, 18, 8, 8]` encoded positions,
    /// - `policy_targets`: `[N, 4672]`, each row renormalised to sum to 1 when it
    ///   carries any probability mass,
    /// - `value_targets`: `[N, 1]`.
    ///
    /// `items` must not be empty, because the per-sample tensors are joined with
    /// `Tensor::cat`.
    fn batch(&self, items: Vec<TrainingSample>, device: &B::Device) -> ChessBatch<B> {
        /// Number of slots in the flattened move encoding used by `encode_move`.
        const POLICY_SIZE: usize = 4672;

        let state_tensors = items
            .iter()
            .map(|item| {
                encode_board_state_perspective(&item.board_state, device).reshape([1, 18, 8, 8])
            })
            .collect::<Vec<_>>();

        // Samples are only read here, so borrow them instead of deep-cloning each one.
        let policy_target_tensors = items
            .iter()
            .map(|item| {
                let mut policy = vec![0.0f32; POLICY_SIZE];
                let stm = item.board_state.board.side_to_move();
                for (&mv, &prob) in &item.policy_target {
                    policy[encode_move(mv, stm)] = prob;
                }
                // Normalise so the row is a probability distribution.
                let sum: f32 = policy.iter().sum();
                if sum > 0.0 {
                    for p in &mut policy {
                        *p /= sum;
                    }
                }
                // Build the data with its final rank-2 shape so it matches the declared
                // tensor rank directly, rather than creating rank-1 data and fixing it up.
                Tensor::<B, 2>::from_floats(TensorData::new(policy, [1, POLICY_SIZE]), device)
            })
            .collect::<Vec<_>>();

        let value_target_tensors = items
            .iter()
            .map(|item| {
                Tensor::<B, 2>::from_floats(
                    TensorData::new(vec![item.value_target], [1, 1]),
                    device,
                )
            })
            .collect::<Vec<_>>();

        let states = Tensor::cat(state_tensors, 0); // [B, 18, 8, 8]
        let policy_targets = Tensor::cat(policy_target_tensors, 0); // [B, 4672]
        let value_targets = Tensor::cat(value_target_tensors, 0); // [B, 1]
        ChessBatch {
            states,
            policy_targets,
            value_targets,
        }
    }
}

View File

@ -0,0 +1,59 @@
/// Multi-label classification output adapted for multiple metrics.
///
/// Unlike a hard-label output, `targets` is a float tensor ("soft" labels); it is
/// only converted to booleans (via `.bool()`) for the confusion-statistics adaptor.
///
/// Supported metrics:
/// - HammingScore
/// - Precision (via ConfusionStatsInput)
/// - Recall (via ConfusionStatsInput)
/// - FBetaScore (via ConfusionStatsInput)
/// - Loss
// NOTE(review): `#[derive(new)]` presumably generates `new(loss, output, targets)`
// in field order — keep the field order stable if callers rely on it.
#[derive(new)]
pub struct MultiLabelSoftClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The label logits or probabilities. Shape: \[batch_size, num_classes\].
pub output: Tensor<B, 2>,
/// The ground truth labels (target values). Shape: \[batch_size, num_classes\].
pub targets: Tensor<B, 2>,
}
impl<B: Backend> ItemLazy for MultiLabelSoftClassificationOutput<B> {
    type ItemSync = MultiLabelSoftClassificationOutput<NdArray>;

    /// Reads all three tensors back in a single transaction and rebuilds the
    /// output on the `NdArray` backend's default device.
    fn sync(self) -> Self::ItemSync {
        let Self {
            loss,
            output,
            targets,
        } = self;

        // The destructuring order below must mirror the registration order.
        let transaction = Transaction::default()
            .register(output)
            .register(loss)
            .register(targets);
        let [output_data, loss_data, targets_data] = transaction
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");

        let device = &Default::default();
        MultiLabelSoftClassificationOutput {
            output: Tensor::from_data(output_data, device),
            loss: Tensor::from_data(loss_data, device),
            targets: Tensor::from_data(targets_data, device),
        }
    }
}
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelSoftClassificationOutput<B> {
    /// Feeds the Hamming-score metric with copies of the outputs and the targets.
    fn adapt(&self) -> HammingScoreInput<B> {
        let predictions = self.output.clone();
        let labels = self.targets.clone();
        HammingScoreInput::new(predictions, labels)
    }
}
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelSoftClassificationOutput<B> {
    /// Feeds the loss metric with a copy of the loss tensor.
    fn adapt(&self) -> LossInput<B> {
        let loss = self.loss.clone();
        LossInput::new(loss)
    }
}
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelSoftClassificationOutput<B> {
    /// Feeds the confusion-statistics metrics; the float targets are converted to
    /// a boolean tensor first.
    fn adapt(&self) -> ConfusionStatsInput<B> {
        let predictions = self.output.clone();
        let labels = self.targets.clone().bool();
        ConfusionStatsInput::new(predictions, labels)
    }
}

134
engine/src/train.rs Normal file
View File

@ -0,0 +1,134 @@
// use crate::model::{ChessBatcher, ChessModel, ChessModelConfig};
// use burn::config::Config;
// use burn::data::dataloader::batcher::Batcher;
// use burn::optim::{Adam, AdamConfig, Optimizer, SimpleOptimizer};
// use burn::optim::adaptor::OptimizerAdaptor;
// use burn::prelude::Backend;
// use burn::tensor::backend::AutodiffBackend;
// use burn::train::TrainStep;
// use crate::mcts::MctsResults;
//
// #[derive(Config, Debug)]
// pub struct ChessTrainerConfig {
// pub model: ChessModelConfig,
// pub optimizer: AdamConfig,
// #[config(default = 10)]
// pub num_epochs: usize,
// #[config(default = 64)]
// pub batch_size: usize,
// #[config(default = 4)]
// pub num_workers: usize,
// #[config(default = 42)]
// pub seed: u64,
// #[config(default = 1.0e-4)]
// pub learning_rate: f64,
// }
//
// impl ChessTrainerConfig {
// pub fn init<B: Backend>(
// model_config: ChessModelConfig,
// optimizer: AdamConfig,
// num_epochs: usize,
// batch_size: usize,
// num_workers: usize,
// seed: u64,
// learning_rate: f64,
// ) -> ChessTrainer<B> {
// ChessTrainer {
// model: model_config::init(),
// optimizer: optimizer.init(),
// num_epochs,
//
//
// }
// }
// }
//
// pub struct ChessTrainer<B: AutodiffBackend> {
// pub model: ChessModel<B>,
// pub optimizer: Adam,
// learning_rate: f64,
// pub batcher: ChessBatcher,
// pub device: B::Device,
// }
//
// impl<B: AutodiffBackend> ChessTrainer<B> {
// pub fn new(model: ChessModel<B>, device: B::Device) -> Self {
// let optimizer = AdamConfig::new().init();
//
// Self {
// model,
// optimizer,
// batcher: ChessBatcher::default(),
// device,
// }
// }
//
// pub fn train_step(&mut self, batch_data: Vec<MctsResults>) -> f32 {
// // 1. Convert to tensors
// let batch = self.batcher.batch(batch_data, &self.device);
//
// // 2. Forward + backward (your TrainStep impl)
// let output = self.model.step(batch);
//
// // 3. Update weights
// self.optimizer.step(, &mut self.model, output.grads);
//
// // 4. Return loss (for logging)
// let loss_tensor = output.item.loss.clone().into_data();
// let loss = loss_tensor.to_vec::<f32>().unwrap()[0];
//
// loss
// }
// }
// fn create_artifact_dir(artifact_dir: &str) {
// // Remove existing artifacts before to get an accurate learner summary
// std::fs::remove_dir_all(artifact_dir).ok();
// std::fs::create_dir_all(artifact_dir).ok();
// }
//
// pub fn train<B: AutodiffBackend>(
// artifact_dir: &str,
// config: ChessTrainingConfig,
// device: B::Device,
// ) {
// create_artifact_dir(artifact_dir);
// config
// .save(format!("{artifact_dir}/config.json"))
// .expect("Config should be saved successfully");
//
// B::seed(&device, config.seed);
//
// let batcher = ChessBatcher::default();
// let dataloader_train = DataLoaderBuilder::new(batcher.clone())
// .batch_size(config.batch_size)
// .shuffle(config.seed)
// .num_workers(config.num_workers)
// .build(MnistDataset::train());
//
// let dataloader_test = DataLoaderBuilder::new(batcher)
// .batch_size(config.batch_size)
// .shuffle(config.seed)
// .num_workers(config.num_workers)
// .build(MnistDataset::test());
//
// let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)
// .metrics((AccuracyMetric::new(), LossMetric::new()))
// .with_file_checkpointer(CompactRecorder::new())
// .num_epochs(config.num_epochs)
// .summary();
//
// let model = config.model.init::<B>(&device);
// let result = training.launch(Learner::new(
// model,
// config.optimizer.init(),
// config.learning_rate,
// ));
//
// result
// .model
// .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
// .expect("Trained model should be saved successfully");
// }

1
engine/src/training.rs Normal file
View File

@ -0,0 +1 @@
pub mod train;

View File

@ -0,0 +1,236 @@
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
use burn::data::dataloader::batcher::Batcher;
use burn::module::{AutodiffModule, Module};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
use burn::tensor::backend::AutodiffBackend;
use chess::ChessMove;
use rand::rngs::ThreadRng;
use rand::seq::SliceRandom;
use rand::RngExt;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
/// Settings for the self-play training loop run by `train`.
pub struct TrainingConfig {
/// Optional wall-clock limit: training stops once more than this many seconds have elapsed.
pub max_time_s: Option<u64>,
/// Optional limit on the number of training iterations.
pub num_iters: Option<u32>,
pub max_depth: u16, // unused
/// File name of the model under `artifacts/`, used for both loading and saving.
pub model_name: String,
/// When true, weights are loaded from `artifacts/<model_name>` before training starts.
pub load_model: bool,
/// Passed to `ChessModelConfig::init` together with `num_blocks`.
pub hidden_channels: usize,
/// Passed to `ChessModelConfig::init` together with `hidden_channels`.
pub num_blocks: usize,
pub batch_size: usize, // num positions sampled to a batch (2048)
pub num_episodes: usize, // num games generated per iteration (5000)
pub buffer_max_size: usize, // max number of samples in the buffer
/// Configuration handed to the MCTS used for self-play.
pub mcts_config: MctsConfig,
/// Adam configuration from which the optimizer is initialised.
pub optimizer: AdamConfig,
/// Learning rate passed to every optimizer step.
pub lr: f64,
}
pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Device) {
let model_path = format!("artifacts/{}", training_config.model_name.as_str());
println!("Creating model...");
let mut model: ChessModel<B> = ChessModelConfig::init(
training_config.hidden_channels,
training_config.num_blocks,
&device,
);
if training_config.load_model {
println!("Loading model {}...", model_path);
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model = model
.load_file(&model_path, &recorder, &device)
.expect("Should be able to load the model weights from the provided file");
}
let train = Arc::new(AtomicBool::new(true));
let train_signal = Arc::clone(&train);
let mut iter: u32 = 0;
let start_time = Instant::now();
ctrlc::set_handler(move || {
train_signal.store(false, Ordering::Relaxed);
println!("Finishing batch before exiting...");
})
.expect("Error setting Ctrl-C handler");
let mut replay_buffer: VecDeque<TrainingSample> = VecDeque::new();
let mut mcts: Mcts<B::InnerBackend> = Mcts {
config: training_config.mcts_config,
_marker: PhantomData,
};
let mut rng = rand::rng();
println!("Starting training...");
while train.load(Ordering::Relaxed) {
let infer_model = model.valid();
// Gen samples
println!("Generating {} games...", training_config.num_episodes);
for episode in 0..training_config.num_episodes {
println!("Episode: {}", episode);
let mut board_state = BoardState::default();
let mut episode_buffer: Vec<MctsResults> = vec![];
while board_state.status == BoardStateStatus::Ongoing {
let results = mcts.search(&board_state, &infer_model, &device);
episode_buffer.push(results);
let temp = if board_state.halfmove_clock < 30 {
1.0
} else {
0.0
};
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
let mv = sample_move(&adjusted, &mut rng).unwrap();
println!("playing move: {}", mv);
board_state.apply_move(mv)
}
for result in episode_buffer.iter().enumerate() {
if board_state.status == BoardStateStatus::Stalemate
|| board_state.status == BoardStateStatus::Threefold
|| board_state.status == BoardStateStatus::FiftyMove
{
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
result.1.clone(),
0.0,
));
} else if board_state.status == BoardStateStatus::WhiteWinner {
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
result.1.clone(),
((result.0 % 2) as f32 * -2.0) + 1.0,
));
} else if board_state.status == BoardStateStatus::BlackWinner {
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
result.1.clone(),
((result.0 % 2) as f32 * 2.0) - 1.0,
));
}
if replay_buffer.len() > training_config.buffer_max_size {
replay_buffer.pop_front();
}
}
}
// train
println!(
"Finished! Training on {} samples...",
training_config.batch_size
);
let mut indices: Vec<usize> = (0..replay_buffer.len()).collect();
indices.shuffle(&mut rng);
let samples = indices
.into_iter()
.take(training_config.batch_size)
.map(|i| replay_buffer[i].clone())
.collect::<Vec<TrainingSample>>();
let batcher = ChessBatcher {};
let batch = batcher.batch(samples, &device);
let mut optim = training_config.optimizer.init();
let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
let grads = output.loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(training_config.lr, model, grads);
iter += 1;
if iter % 100 == 0 {
println!("Completed {} iterations", iter);
}
if training_config.max_time_s.is_some()
&& start_time.elapsed().as_secs() > training_config.max_time_s.unwrap()
{
println!("Training stopping due to time limit...");
break;
}
if training_config.num_iters.is_some() && iter >= training_config.num_iters.unwrap() {
println!(
"Training stopping due to iteration limit ({} iters completed)...",
iter
);
break;
}
}
println!("Saving model...");
// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(&model_path, &recorder)
.expect("Should be able to save the model");
println!("Model saved in {:?}, exiting training.", model_path);
return;
}
/// Reshapes a move distribution with a sampling temperature.
///
/// - An empty input yields an empty map.
/// - `temperature == 0.0` is the greedy case: the result contains only the
///   highest-weighted move, with probability `1.0`.
/// - Otherwise every weight is raised to the power `1 / temperature` and the result
///   is renormalised to sum to 1. If the powered weights do not sum to a positive
///   number they are returned unnormalised.
fn apply_temperature(
    visits: &HashMap<ChessMove, f32>,
    temperature: f32,
) -> HashMap<ChessMove, f32> {
    if visits.is_empty() {
        return HashMap::new();
    }

    // Special case: deterministic selection
    if temperature == 0.0 {
        // `total_cmp` is a total order on floats, so a NaN weight cannot cause a
        // panic the way `partial_cmp(..).unwrap()` would.
        let (&best_move, _) = visits
            .iter()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .expect("visits was checked to be non-empty");
        return HashMap::from([(best_move, 1.0)]);
    }

    let inv_temp = 1.0 / temperature;

    // Step 1: apply exponent
    let mut adjusted: HashMap<ChessMove, f32> = visits
        .iter()
        .map(|(&mv, &weight)| (mv, weight.powf(inv_temp)))
        .collect();

    // Step 2: normalize
    let sum: f32 = adjusted.values().sum();
    if sum <= 0.0 {
        return adjusted; // fallback (shouldn't happen in normal MCTS)
    }
    for v in adjusted.values_mut() {
        *v /= sum;
    }
    adjusted
}
/// Draws one move from `dist`, treating the values as probabilities.
///
/// Returns `None` only when `dist` is empty. A move with probability `0.0` is never
/// selected by the draw itself. If floating-point drift leaves the walk without a
/// pick (the probabilities sum to slightly less than the drawn number), the most
/// probable move is returned.
fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut ThreadRng) -> Option<ChessMove> {
    let mut remaining: f32 = rng.random_range(0.0..1.0);
    for (&mv, &prob) in dist {
        // Strict comparison: with `remaining == 0.0` a zero-probability entry must
        // not be picked.
        if remaining < prob {
            return Some(mv);
        }
        remaining -= prob;
    }
    // fallback due to floating point drift
    dist.iter()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(&mv, _)| mv)
}

8
uci/Cargo.toml Normal file
View File

@ -0,0 +1,8 @@
[package]
name = "uci"
version = "0.1.0"
edition = "2024"
[dependencies]
engine = { path = "../engine" }
chess = "3.2.0"

216
uci/src/main.rs Normal file
View File

@ -0,0 +1,216 @@
use chess::{ChessMove, Game};
use engine::{legal_action_mask, Engine};
use std::io;
use std::str::FromStr;
const ENGINE_NAME: &str = "Chess Dragon";
const ENGINE_AUTHOR: &str = "Drake Marino";
/// Reads UCI commands from standard input, one per line, and dispatches them.
///
/// The loop ends on `quit`, on end of input, or on a read error. Blank lines are
/// skipped and unknown commands are ignored (with a note on stderr), as the UCI
/// protocol requires, instead of aborting the engine.
pub fn uci_loop() {
    let stdin = io::stdin();
    let mut game = Game::new();
    let mut engine = Engine::default();
    let mut line = String::new();
    loop {
        line.clear();
        match stdin.read_line(&mut line) {
            // `Ok(0)` is end of input; without this check the loop would spin forever.
            Ok(0) | Err(_) => break,
            Ok(_) => {}
        }
        let input = line.trim();
        let Some(command) = input.split_whitespace().next() else {
            continue;
        };
        match command {
            "uci" => uci_uci(),
            "isready" => println!("readyok"),
            "id" => uci_id(),
            "option" => uci_option(),
            "setoption" => uci_setoption(input),
            "ucinewgame" => uci_ucinewgame(&mut game),
            "position" => uci_position(input, &mut game),
            "go" => uci_go(input, &mut game, &mut engine),
            // No search runs in the background yet, so there is nothing to stop.
            "stop" => {}
            "quit" => break,
            _ => eprintln!("Ignoring unknown command: {input}"),
        }
    }
}
// UCI Commands
/// Prints the engine's `id name` and `id author` identification lines.
fn uci_id() {
    println!("id name {ENGINE_NAME}");
    println!("id author {ENGINE_AUTHOR}");
}
/// Advertises the engine's configurable options; there are none yet, so nothing is printed.
fn uci_option() {
// none currently implemented
}
/// Handles the `uci` command: prints the identification lines, then the option
/// list (currently empty), and finally `uciok`.
fn uci_uci() {
uci_id();
uci_option();
println!("uciok");
}
/// Placeholder for the `setoption` command; `input` (the full command line) is
/// currently ignored.
fn uci_setoption(input: &str) {
// TODO
}
/// Handles `ucinewgame` by replacing the current game with a fresh `Game::new()`.
fn uci_ucinewgame(game: &mut Game) {
*game = Game::new();
}
/// Handles `position [startpos | fen <6 FEN fields>] [moves <m1> <m2> ...]`.
///
/// The new position is built in a scratch `Game` and only written to `game` once the
/// whole command has been parsed and every listed move has been applied. On any
/// problem (missing or unknown sub-command, malformed FEN, unparsable or illegal
/// move) a message is written to stderr and `game` is left unchanged, so malformed
/// GUI input cannot crash the engine or leave a half-applied position behind.
fn uci_position(input: &str, game: &mut Game) {
    let mut tokens = input.split_whitespace();
    if tokens.next() != Some("position") {
        eprintln!("position command not provided!");
        return;
    }

    let mut new_game = match tokens.next() {
        Some("startpos") => Game::new(),
        Some("fen") => {
            // FEN has 6 space-separated fields
            let fen_fields: Vec<&str> = tokens.by_ref().take(6).collect();
            if fen_fields.len() != 6 {
                eprintln!("fen field invalid!");
                return;
            }
            match Game::from_str(&fen_fields.join(" ")) {
                Ok(parsed) => parsed,
                Err(_) => {
                    eprintln!("Invalid board position");
                    return;
                }
            }
        }
        _ => {
            eprintln!("Position command invalid!");
            return;
        }
    };

    // Everything after the `moves` keyword is a move to play from the base position.
    for mv in tokens.skip_while(|&t| t != "moves").skip(1) {
        let Ok(parsed) = ChessMove::from_str(mv) else {
            eprintln!("Invalid move: {mv}");
            return;
        };
        // `make_move` reports whether the move was accepted; don't drop that silently.
        if !new_game.make_move(parsed) {
            eprintln!("Illegal move: {mv}");
            return;
        }
    }

    *game = new_game;
}
fn uci_go(input: &str, game: &mut Game, engine: &mut Engine) {
let parts: Vec<&str> = input.split_whitespace().collect();
let mut wtime = None;
let mut btime = None;
let mut winc = None;
let mut binc = None;
let mut movetime = None;
let mut max_depth = None;
let mut max_nodes = None;
let mut i = 1; // Skip "go"
while i < parts.len() {
match parts[i] {
"wtime" => {
if i + 1 < parts.len() {
wtime = parts[i + 1].parse::<u64>().ok();
i += 2;
} else {
i += 1;
}
}
"btime" => {
if i + 1 < parts.len() {
btime = parts[i + 1].parse::<u64>().ok();
i += 2;
} else {
i += 1;
}
}
"winc" => {
if i + 1 < parts.len() {
winc = parts[i + 1].parse::<u64>().ok();
i += 2;
} else {
i += 1;
}
}
"binc" => {
if i + 1 < parts.len() {
binc = parts[i + 1].parse::<u64>().ok();
i += 2;
} else {
i += 1;
}
}
"movetime" => {
if i + 1 < parts.len() {
movetime = parts[i + 1].parse::<u64>().ok();
i += 2;
} else {
i += 1;
}
}
"depth" => {
if i + 1 < parts.len() {
max_depth = parts[i + 1].parse::<u16>().ok();
i += 2;
} else {
i += 1;
}
}
"infinite" => {
max_depth = Some(100);
i += 1;
}
"nodes" => {
if i + 1 < parts.len() {
max_nodes = parts[i + 1].parse::<usize>().ok();
i += 2;
} else {
i += 1;
}
}
_ => {
i += 1;
}
}
}
// Update search settings
if let Some(wt) = wtime {
engine.search_settings.wtime = wt;
}
if let Some(bt) = btime {
engine.search_settings.btime = bt;
}
if let Some(wi) = winc {
engine.search_settings.winc = wi;
}
if let Some(bi) = binc {
engine.search_settings.binc = bi;
}
if let Some(mt) = movetime {
engine.search_settings.movetime = Some(mt);
}
if let Some(max_depth) = max_depth {
engine.search_settings.max_depth = Some(max_depth);
}
if let Some(nodes) = max_nodes {
engine.search_settings.max_nodes = Some(nodes);
}
}
/// Entry point of the UCI binary: runs the command loop until it returns.
fn main() {
uci_loop();
}

6
web/Cargo.toml Normal file
View File

@ -0,0 +1,6 @@
[package]
name = "web"
version = "0.1.0"
edition = "2024"
[dependencies]

3
web/src/main.rs Normal file
View File

@ -0,0 +1,3 @@
/// Placeholder entry point for the `web` crate; it only prints a greeting.
fn main() {
println!("Hello, world!");
}