first working
This commit is contained in:
commit
764252b0c4
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
||||||
10
.idea/.gitignore
generated
vendored
Normal file
10
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Ignored default folder with query files
|
||||||
|
/queries/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
15
.idea/chess_dragon.iml
generated
Normal file
15
.idea/chess_dragon.iml
generated
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="EMPTY_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/cli/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/engine/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/uci/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/web/src" isTestSource="false" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="AgentMigrationStateService">
|
||||||
|
<option name="migrationStatus" value="COMPLETED" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/copilot.data.migration.ask.xml
generated
Normal file
6
.idea/copilot.data.migration.ask.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="AskMigrationStateService">
|
||||||
|
<option name="migrationStatus" value="COMPLETED" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/copilot.data.migration.ask2agent.xml
generated
Normal file
6
.idea/copilot.data.migration.ask2agent.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Ask2AgentMigrationStateService">
|
||||||
|
<option name="migrationStatus" value="COMPLETED" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/copilot.data.migration.edit.xml
generated
Normal file
6
.idea/copilot.data.migration.edit.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="EditMigrationStateService">
|
||||||
|
<option name="migrationStatus" value="COMPLETED" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/chess_dragon.iml" filepath="$PROJECT_DIR$/.idea/chess_dragon.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
7461
Cargo.lock
generated
Normal file
7461
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
4
Cargo.toml
Normal file
4
Cargo.toml
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[workspace]
|
||||||
|
members = ["engine", "cli", "uci", "web"]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
8
cli/Cargo.toml
Normal file
8
cli/Cargo.toml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[package]
|
||||||
|
name = "cli"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
engine = { path = "../engine" }
|
||||||
|
|
||||||
3
cli/src/commands.rs
Normal file
3
cli/src/commands.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
pub mod benchmark;
|
||||||
|
pub mod play;
|
||||||
|
pub mod train;
|
||||||
1
cli/src/commands/benchmark.rs
Normal file
1
cli/src/commands/benchmark.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
0
cli/src/commands/play.rs
Normal file
0
cli/src/commands/play.rs
Normal file
4
cli/src/commands/train.rs
Normal file
4
cli/src/commands/train.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
/// Entry point of the `train` subcommand.
///
/// Currently a stub: the call into the training routine is still commented out, so invoking
/// this function does nothing.
pub fn run() {
    // train(TrainingConfig {}, );
}
|
||||||
16
cli/src/main.rs
Normal file
16
cli/src/main.rs
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
mod commands;
|
||||||
|
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
match args.get(1).map(|s| s.as_str()) {
|
||||||
|
Some("train") => commands::train::run(),
|
||||||
|
Some("selfplay") => commands::selfplay::run(),
|
||||||
|
Some("benchmark") => commands::benchmark::run(),
|
||||||
|
_ => {
|
||||||
|
println!("unknown command");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
12
engine/Cargo.toml
Normal file
12
engine/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "engine"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
chess = "3.2.0"
|
||||||
|
burn = { version = "0.21.0", features = ["std", "tui", "train", "cuda", "fusion", "wgpu", "ndarray"], default-features = false }
|
||||||
|
burn-ndarray = "0.21.0"
|
||||||
|
rand_distr = "0.6.0"
|
||||||
|
rand = "0.10.1"
|
||||||
|
ctrlc = "3.5.2"
|
||||||
70
engine/src/lib.rs
Normal file
70
engine/src/lib.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
pub mod mcts;
|
||||||
|
mod net;
|
||||||
|
mod train;
|
||||||
|
pub mod training;
|
||||||
|
|
||||||
|
use chess::Game;
|
||||||
|
|
||||||
|
pub const DEFAULT_MAX_DEPTH: u16 = 6;
|
||||||
|
pub const DEFAULT_PLAYER_TIME_REMAINING_MS: u64 = 120_000; // 2 minutes
|
||||||
|
pub const DEFAULT_PLAYER_INCREMENT_MS: u64 = 0;
|
||||||
|
|
||||||
|
pub struct Engine {
|
||||||
|
pub game: Game,
|
||||||
|
pub search_settings: SearchSettings,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Limits that bound a search.
///
/// The clock fields default to [`DEFAULT_PLAYER_TIME_REMAINING_MS`] and
/// [`DEFAULT_PLAYER_INCREMENT_MS`] when built through [`SearchSettings::new`].
// NOTE(review): field names mirror the UCI `go` parameters; the white/black reading of the
// `w`/`b` prefixes is assumed from that convention — confirm.
pub struct SearchSettings {
    /// Remaining clock time in milliseconds (presumably White's).
    pub wtime: u64,
    /// Remaining clock time in milliseconds (presumably Black's).
    pub btime: u64,
    /// Per-move increment in milliseconds (presumably White's).
    pub winc: u64,
    /// Per-move increment in milliseconds (presumably Black's).
    pub binc: u64,
    /// Optional fixed time budget for a single move; `None` when unset.
    pub movetime: Option<u64>,
    /// Optional depth limit; `None` when unset.
    pub max_depth: Option<u16>,
    /// Optional node limit; `None` when unset.
    pub max_nodes: Option<usize>,
}
|
||||||
|
|
||||||
|
impl Default for SearchSettings {
|
||||||
|
fn default() -> Self {
|
||||||
|
SearchSettings::new(None, None, None, None, None, None, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SearchSettings {
|
||||||
|
pub fn new(
|
||||||
|
wtime: Option<u64>,
|
||||||
|
btime: Option<u64>,
|
||||||
|
winc: Option<u64>,
|
||||||
|
binc: Option<u64>,
|
||||||
|
movetime: Option<u64>,
|
||||||
|
max_depth: Option<u16>,
|
||||||
|
max_nodes: Option<usize>,
|
||||||
|
) -> Self {
|
||||||
|
SearchSettings {
|
||||||
|
wtime: wtime.unwrap_or(DEFAULT_PLAYER_TIME_REMAINING_MS),
|
||||||
|
btime: btime.unwrap_or(DEFAULT_PLAYER_TIME_REMAINING_MS),
|
||||||
|
winc: winc.unwrap_or(DEFAULT_PLAYER_INCREMENT_MS),
|
||||||
|
binc: binc.unwrap_or(DEFAULT_PLAYER_INCREMENT_MS),
|
||||||
|
movetime,
|
||||||
|
max_depth,
|
||||||
|
max_nodes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Engine {
|
||||||
|
fn default() -> Self {
|
||||||
|
Engine::new(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Engine {
|
||||||
|
pub fn new(search_settings: Option<SearchSettings>) -> Self {
|
||||||
|
Engine {
|
||||||
|
game: Game::new(),
|
||||||
|
search_settings: search_settings.unwrap_or(SearchSettings::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn go(&mut self) {}
|
||||||
|
}
|
||||||
104
engine/src/main.rs
Normal file
104
engine/src/main.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
#![recursion_limit = "256"]
|
||||||
|
|
||||||
|
use burn::backend::{Autodiff, Cuda};
|
||||||
|
use burn::optim::AdamConfig;
|
||||||
|
use engine::mcts::MctsConfig;
|
||||||
|
use engine::training::train::{train, TrainingConfig};
|
||||||
|
// fn main() {
|
||||||
|
// type MyBackend = Wgpu<f32, i32>;
|
||||||
|
//
|
||||||
|
// let device = Default::default();
|
||||||
|
// let model = ChessModelConfig::new(10, 512).init::<MyBackend>(&device);
|
||||||
|
//
|
||||||
|
// println!("{model}");
|
||||||
|
// }
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// type MyBackend = Wgpu<f32, i32>;
|
||||||
|
type MyBackend = Cuda<f32, i32>;
|
||||||
|
// type MyBackend = NdArray<f32, i32>;
|
||||||
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||||
|
// let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
|
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
||||||
|
let device = burn::backend::cuda::CudaDevice::default();
|
||||||
|
|
||||||
|
let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25);
|
||||||
|
|
||||||
|
let adam_config = AdamConfig::new();
|
||||||
|
|
||||||
|
let training_config = TrainingConfig {
|
||||||
|
max_time_s: None,
|
||||||
|
num_iters: None,
|
||||||
|
max_depth: 100, // unused
|
||||||
|
model_name: String::from("Test1"),
|
||||||
|
load_model: false,
|
||||||
|
hidden_channels: 64,
|
||||||
|
num_blocks: 4,
|
||||||
|
batch_size: 128, // num positions sampled to a batch (2048)
|
||||||
|
num_episodes: 10, // num games generated per iteration (5000)
|
||||||
|
buffer_max_size: 200_000, // max number of samples in the buffer
|
||||||
|
mcts_config,
|
||||||
|
optimizer: adam_config,
|
||||||
|
lr: 2e-4,
|
||||||
|
};
|
||||||
|
|
||||||
|
train::<MyAutodiffBackend>(training_config, device);
|
||||||
|
|
||||||
|
// let mut board_state: Option<BoardState> = Some(BoardState::new(
|
||||||
|
// Board::default(),
|
||||||
|
// 0,
|
||||||
|
// HashMap::<u64, u8>::new(),
|
||||||
|
// ));
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// // let start = Instant::now();
|
||||||
|
// // rollout(&board, &PolicyNetwork::new());
|
||||||
|
// // let end = Instant::now();
|
||||||
|
// // println!("Rollout: {:?}", end - start);
|
||||||
|
// let mut previous_state = board_state.clone().unwrap();
|
||||||
|
//
|
||||||
|
// let mut move_history: String = String::new();
|
||||||
|
//
|
||||||
|
// let mut mcts = MctsConfig::new()
|
||||||
|
//
|
||||||
|
// while board_state.is_some() {
|
||||||
|
// // mcts now returns Option<(BoardState, String)>
|
||||||
|
// let results = Mcts::search(&previous_state, 1000000, 0);
|
||||||
|
// if !results.move_dist.is_empty() {
|
||||||
|
// let mut best_prob: f32 = -1.0;
|
||||||
|
// let mut best_move: &ChessMove = &ChessMove::default();
|
||||||
|
//
|
||||||
|
// for m in results.move_dist.iter() {
|
||||||
|
// if m.1 > &best_prob {
|
||||||
|
// best_move = m.0;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// previous_state = results.board_state;
|
||||||
|
// previous_state.board = previous_state.board.make_move_new(*best_move);
|
||||||
|
// board_state = Some(previous_state.clone());
|
||||||
|
//
|
||||||
|
// println!(
|
||||||
|
// "Chosen move: {}, board state: {}",
|
||||||
|
// best_move.to_string(),
|
||||||
|
// previous_state.board
|
||||||
|
// );
|
||||||
|
// move_history += &*best_move.to_string();
|
||||||
|
// move_history += " ";
|
||||||
|
// } else {
|
||||||
|
// board_state = None;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// println!("Finished game w/ board state: {}", previous_state.board);
|
||||||
|
// println!("Move history: {}", move_history);
|
||||||
|
// println!(
|
||||||
|
// "Game ended! status: {:?}, clock: {}, 3-fold: {}",
|
||||||
|
// previous_state.board.status(),
|
||||||
|
// previous_state.halfmove_clock,
|
||||||
|
// previous_state
|
||||||
|
// .repetition_table
|
||||||
|
// .iter()
|
||||||
|
// .any(|(_, ct)| ct >= &3u8)
|
||||||
|
// );
|
||||||
|
}
|
||||||
416
engine/src/mcts.rs
Normal file
416
engine/src/mcts.rs
Normal file
@ -0,0 +1,416 @@
|
|||||||
|
// use crate::networks::PolicyNetwork;
|
||||||
|
use crate::mcts::BoardStateStatus::{FiftyMove, Ongoing, Threefold};
|
||||||
|
use crate::net::encoding::{encode_board_state_perspective, encode_move};
|
||||||
|
use crate::net::model::ChessModel;
|
||||||
|
use burn::prelude::Backend;
|
||||||
|
use burn::Tensor;
|
||||||
|
use chess::BoardStatus::{Checkmate, Stalemate};
|
||||||
|
use chess::Color::White;
|
||||||
|
use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
|
||||||
|
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
pub struct Node {
|
||||||
|
pub prior: f32,
|
||||||
|
pub children: Vec<usize>,
|
||||||
|
pub visit_count: u32,
|
||||||
|
pub value_sum: f32,
|
||||||
|
pub board_state: BoardState,
|
||||||
|
pub last_move: Option<ChessMove>, // move that produced this node from its parent
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
pub fn new(prior: f32, board_state: BoardState, last_move: Option<ChessMove>) -> Node {
|
||||||
|
Node {
|
||||||
|
prior,
|
||||||
|
children: vec![],
|
||||||
|
visit_count: 0,
|
||||||
|
value_sum: 0.0,
|
||||||
|
board_state,
|
||||||
|
last_move,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn value(&self) -> f32 {
|
||||||
|
self.value_sum / self.visit_count as f32
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn select_child(&self, arena: &[Node], c_puct: &f32) -> usize {
|
||||||
|
self.children
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.max_by(|&a, &b| {
|
||||||
|
ucb_score(self, &arena[a], c_puct)
|
||||||
|
.partial_cmp(&ucb_score(self, &arena[b], c_puct))
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.expect("select_child on leaf")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Node {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self::new(self.prior, self.board_state.clone(), self.last_move.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct MctsResults {
|
||||||
|
pub board_state: BoardState,
|
||||||
|
pub move_dist: HashMap<ChessMove, f32>,
|
||||||
|
pub value: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MctsResults {
|
||||||
|
pub fn new(
|
||||||
|
board_state: BoardState,
|
||||||
|
move_dist: HashMap<ChessMove, f32>,
|
||||||
|
value: f32,
|
||||||
|
) -> MctsResults {
|
||||||
|
MctsResults {
|
||||||
|
board_state,
|
||||||
|
move_dist,
|
||||||
|
value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tunable parameters of the tree search.
#[derive(Debug)]
pub struct MctsConfig {
    /// Number of select / expand / back-propagate iterations per search.
    pub num_simulations: usize,
    /// Multiplier of the prior-driven exploration term in the UCB score.
    pub c_puct: f32,
    /// Shape parameter of the Gamma draws used to build the root noise.
    pub dirichlet_alpha: f32,
    /// Weight of the noise when it is mixed into the root priors.
    pub dirichlet_epsilon: f32,
}
|
||||||
|
|
||||||
|
impl MctsConfig {
    /// Creates a configuration from explicit values; no validation is performed.
    pub fn new(
        num_simulations: usize,
        c_puct: f32,
        dirichlet_alpha: f32,
        dirichlet_epsilon: f32,
    ) -> MctsConfig {
        Self { num_simulations, c_puct, dirichlet_alpha, dirichlet_epsilon }
    }
}
|
||||||
|
|
||||||
|
impl Default for MctsConfig {
|
||||||
|
fn default() -> MctsConfig {
|
||||||
|
MctsConfig::new(400, 1.0, 0.05, 0.25)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Mcts<B: Backend> {
|
||||||
|
pub config: MctsConfig,
|
||||||
|
pub _marker: PhantomData<B>, // if B not otherwise stored
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Mcts<B> {
|
||||||
|
pub fn search(
|
||||||
|
&mut self,
|
||||||
|
board_state: &BoardState,
|
||||||
|
model: &ChessModel<B>,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> MctsResults {
|
||||||
|
let mut nodes = Vec::<Node>::new();
|
||||||
|
|
||||||
|
let root = 0;
|
||||||
|
nodes.push(Node::new(0.0, board_state.clone(), None));
|
||||||
|
|
||||||
|
self.expand(root, &mut nodes, model, device);
|
||||||
|
|
||||||
|
// 👇 APPLY DIRICHLET NOISE HERE
|
||||||
|
self.add_dirichlet_noise(root, &mut nodes);
|
||||||
|
|
||||||
|
for i in 0..self.config.num_simulations {
|
||||||
|
// if i % 10 == 0 {
|
||||||
|
// println!("mcts sim #{}", i);
|
||||||
|
// }
|
||||||
|
let mut path = vec![root];
|
||||||
|
let mut current = root;
|
||||||
|
|
||||||
|
while !nodes[current].children.is_empty() {
|
||||||
|
current = nodes[current].select_child(&nodes, &self.config.c_puct);
|
||||||
|
path.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
let value: f32 = self.expand(current, &mut nodes, model, device);
|
||||||
|
|
||||||
|
let color = nodes[current].board_state.board.side_to_move();
|
||||||
|
self.backpropagate(&mut nodes, &path, value, color);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut move_dist: HashMap<ChessMove, f32> = HashMap::new(); // TODO: make vec<(Chessmove, f32)>
|
||||||
|
for idx in nodes[root].children.iter() {
|
||||||
|
move_dist.insert(
|
||||||
|
nodes[*idx].last_move.expect("move didnt exist"),
|
||||||
|
nodes[*idx].visit_count as f32 / self.config.num_simulations as f32,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
MctsResults::new(board_state.clone(), move_dist, nodes[root].value())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand(
|
||||||
|
&mut self,
|
||||||
|
node_idx: usize,
|
||||||
|
arena: &mut Vec<Node>,
|
||||||
|
model: &ChessModel<B>,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> f32 {
|
||||||
|
let state: Tensor<B, 4> =
|
||||||
|
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
||||||
|
.reshape([1, 18, 8, 8]);
|
||||||
|
let start = Instant::now();
|
||||||
|
let (policy_head, value_head) = model.forward(state);
|
||||||
|
println!("time: {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let legal_moves: Vec<ChessMove> =
|
||||||
|
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();
|
||||||
|
|
||||||
|
let policy = policy_head.into_data().to_vec::<f32>().unwrap();
|
||||||
|
|
||||||
|
for mv in legal_moves {
|
||||||
|
let stm = arena[node_idx].board_state.board.side_to_move();
|
||||||
|
let idx = encode_move(mv, stm);
|
||||||
|
let prior = policy[idx];
|
||||||
|
|
||||||
|
let mut new_board = arena[node_idx].board_state.clone();
|
||||||
|
new_board.apply_move(mv);
|
||||||
|
|
||||||
|
let child_idx = arena.len();
|
||||||
|
|
||||||
|
arena.push(Node::new(prior, new_board, Some(mv)));
|
||||||
|
arena[node_idx].children.push(child_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
value_head.into_data().to_vec().unwrap()[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn backpropagate(&mut self, nodes: &mut [Node], path: &[usize], value: f32, color: Color) {
|
||||||
|
for &idx in path {
|
||||||
|
let node = &mut nodes[idx];
|
||||||
|
|
||||||
|
if node.board_state.board.side_to_move() == color {
|
||||||
|
node.value_sum += value;
|
||||||
|
} else {
|
||||||
|
node.value_sum -= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
node.visit_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_dirichlet_noise(&mut self, node_id: usize, nodes: &mut Vec<Node>) {
|
||||||
|
let node_children = &nodes[node_id].children.clone();
|
||||||
|
|
||||||
|
let n = node_children.len();
|
||||||
|
if n == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let alpha = self.config.dirichlet_alpha;
|
||||||
|
let epsilon = self.config.dirichlet_epsilon;
|
||||||
|
|
||||||
|
let noise = dirichlet_sample(n, alpha);
|
||||||
|
|
||||||
|
for (node_idx, noise_val) in node_children.iter().zip(noise) {
|
||||||
|
let prev_prior = nodes[*node_idx].prior;
|
||||||
|
nodes[*node_idx].prior = (1.0 - epsilon) * prev_prior + epsilon * noise_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dirichlet_sample(size: usize, alpha: f32) -> Vec<f32> {
|
||||||
|
use rand_distr::{Distribution, Gamma};
|
||||||
|
|
||||||
|
let gamma = Gamma::new(alpha as f64, 1.0).unwrap();
|
||||||
|
|
||||||
|
let mut samples: Vec<f32> = (0..size)
|
||||||
|
.map(|_| gamma.sample(&mut rand::rng()) as f32)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let sum: f32 = samples.iter().sum();
|
||||||
|
|
||||||
|
for x in &mut samples {
|
||||||
|
*x /= sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
samples
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Selection score of `child` as seen from `parent`.
///
/// The score is the sum of two terms:
/// * exploitation: the negated mean value of the child (its value is stored from the opposing
///   side's point of view), or `0.0` while the child is unvisited;
/// * exploration: `c_puct * prior * sqrt(parent visits) / (1 + child visits)`.
fn ucb_score(parent: &Node, child: &Node, c_puct: &f32) -> f32 {
    let parent_visits_root = (parent.visit_count as f32).sqrt();
    let exploration =
        *c_puct * child.prior * parent_visits_root / (1.0 + child.visit_count as f32);

    let exploitation = if child.visit_count > 0 {
        -child.value()
    } else {
        0.0
    };

    exploitation + exploration
}
|
||||||
|
|
||||||
|
/// Picks the most-visited child of `root`.
///
/// Returns the child's board state and its move rendered as a string, or `None` when the root
/// has no children, no child has been visited, or the chosen child carries no move. When
/// several children share the highest visit count, the one listed first wins.
pub fn find_best_move(nodes: &[Node], root: usize) -> Option<(BoardState, String)> {
    // `max_by_key` keeps the last maximum it meets, so walking the children back to front
    // yields the first child with the highest count.
    let best_idx = nodes[root]
        .children
        .iter()
        .copied()
        .filter(|&child| nodes[child].visit_count > 0)
        .rev()
        .max_by_key(|&child| nodes[child].visit_count)?;

    let best = &nodes[best_idx];
    println!("best_child visits: {}", best.visit_count);

    let uci_move = best.last_move?.to_string();
    Some((best.board_state.clone(), uci_move))
}
|
||||||
|
|
||||||
|
pub fn print_tree_arena(nodes: &[Node], root: usize, depth: usize) {
|
||||||
|
let node = &nodes[root];
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{}- Value: {:.2}, Visits: {}, Prior: {:.2}, Side: {:?}",
|
||||||
|
" ".repeat(depth),
|
||||||
|
node.value_sum,
|
||||||
|
node.visit_count,
|
||||||
|
node.prior,
|
||||||
|
node.board_state.board.side_to_move()
|
||||||
|
);
|
||||||
|
|
||||||
|
for &child_idx in &node.children {
|
||||||
|
print_tree_arena(nodes, child_idx, depth + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn material_of(piece: Piece) -> f32 {
|
||||||
|
match piece {
|
||||||
|
Queen => 9.0,
|
||||||
|
Rook => 5.0,
|
||||||
|
Bishop => 3.0,
|
||||||
|
Knight => 3.0,
|
||||||
|
Pawn => 1.0,
|
||||||
|
_ => 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MATERIAL_VALUE_DIVISOR: f32 = 40.0;
|
||||||
|
|
||||||
|
pub fn heuristic_eval(board: &Board, perspective: Color) -> f32 {
|
||||||
|
let mut value = 0.0;
|
||||||
|
|
||||||
|
// material
|
||||||
|
for color in ALL_COLORS {
|
||||||
|
for piece in ALL_PIECES {
|
||||||
|
let bitboard = board.color_combined(color).0 & board.pieces(piece).0;
|
||||||
|
let total_val =
|
||||||
|
(bitboard.count_ones() as f32 * material_of(piece)) / MATERIAL_VALUE_DIVISOR;
|
||||||
|
|
||||||
|
if color == perspective {
|
||||||
|
value += total_val;
|
||||||
|
} else {
|
||||||
|
value -= total_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
|
||||||
|
// board.checkers()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result state of a game as tracked by [`BoardState`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoardStateStatus {
    /// No terminal condition has been recorded.
    Ongoing,
    /// Draw: the side to move has no legal move and is not in check.
    Stalemate,
    /// Black has been checkmated.
    WhiteWinner,
    /// White has been checkmated.
    BlackWinner,
    /// Draw: the current position has been reached at least three times.
    Threefold,
    /// Draw: the half-move clock reached 100.
    FiftyMove,
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct BoardState {
|
||||||
|
pub board: Board,
|
||||||
|
pub halfmove_clock: u8,
|
||||||
|
pub repetition_table: HashMap<u64, u8>,
|
||||||
|
pub status: BoardStateStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BoardState {
|
||||||
|
pub fn new(board: Board, halfmove_clock: u8, repetition_table: HashMap<u64, u8>) -> BoardState {
|
||||||
|
BoardState {
|
||||||
|
board,
|
||||||
|
halfmove_clock,
|
||||||
|
repetition_table: repetition_table,
|
||||||
|
status: Ongoing,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_move(&mut self, mv: ChessMove) {
|
||||||
|
self.halfmove_clock += 1;
|
||||||
|
if self.board.piece_on(mv.get_source()) == Some(Pawn)
|
||||||
|
|| self.board.piece_on(mv.get_dest()).is_some()
|
||||||
|
{
|
||||||
|
self.halfmove_clock = 0;
|
||||||
|
}
|
||||||
|
self.board = self.board.make_move_new(mv);
|
||||||
|
|
||||||
|
let board_hash = self.board.get_hash();
|
||||||
|
|
||||||
|
let current_rep = self.repetition_table.get(&board_hash);
|
||||||
|
let mut new_rep = 1;
|
||||||
|
if current_rep.is_some() {
|
||||||
|
new_rep = current_rep.unwrap() + 1;
|
||||||
|
}
|
||||||
|
self.repetition_table.insert(board_hash, new_rep);
|
||||||
|
|
||||||
|
if self.board.status() == Checkmate {
|
||||||
|
if self.board.side_to_move() == White {
|
||||||
|
// white's move after black played the mating move
|
||||||
|
self.status = BoardStateStatus::BlackWinner;
|
||||||
|
} else {
|
||||||
|
self.status = BoardStateStatus::WhiteWinner;
|
||||||
|
}
|
||||||
|
} else if self.board.status() == Stalemate {
|
||||||
|
self.status = BoardStateStatus::Stalemate;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.halfmove_clock >= 100 {
|
||||||
|
self.status = FiftyMove;
|
||||||
|
}
|
||||||
|
|
||||||
|
if new_rep >= 3 {
|
||||||
|
self.status = Threefold;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BoardState {
|
||||||
|
fn default() -> BoardState {
|
||||||
|
BoardState::new(Board::default(), 0, HashMap::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
2
engine/src/net.rs
Normal file
2
engine/src/net.rs
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
pub mod encoding;
|
||||||
|
pub mod model;
|
||||||
299
engine/src/net/encoding.rs
Normal file
299
engine/src/net/encoding.rs
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
use crate::mcts::BoardState;
|
||||||
|
use burn::prelude::*;
|
||||||
|
use chess::{ChessMove, Color, File, Piece, Rank, Square, ALL_SQUARES};
|
||||||
|
/*
|
||||||
|
Input planes:
|
||||||
|
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
||||||
|
7-12: opponent pieces
|
||||||
|
13: your kingside
|
||||||
|
14: your queenside
|
||||||
|
15: opponent kingside
|
||||||
|
16: opponent queenside
|
||||||
|
17: rep >= 2
|
||||||
|
18: En passant square (gap pass over)
|
||||||
|
|
||||||
|
|
||||||
|
Output:
|
||||||
|
4672 moves represented through planes
|
||||||
|
1-56: sliding moves
|
||||||
|
57-64: Knight moves
|
||||||
|
65-73: underpromotions
|
||||||
|
*/
|
||||||
|
|
||||||
|
pub fn encode_board_state_perspective<B: Backend>(
|
||||||
|
state: &BoardState,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Tensor<B, 3> {
|
||||||
|
let mut planes = vec![0.0f32; 18 * 64];
|
||||||
|
|
||||||
|
let board = &state.board;
|
||||||
|
let us = board.side_to_move();
|
||||||
|
let them = !us;
|
||||||
|
|
||||||
|
let idx = |plane: usize, rank: usize, file: usize| -> usize { plane * 64 + rank * 8 + file };
|
||||||
|
|
||||||
|
let flip = us == Color::Black;
|
||||||
|
|
||||||
|
for &square in ALL_SQUARES.iter() {
|
||||||
|
if let Some(piece) = board.piece_on(square) {
|
||||||
|
let color = board.color_on(square).unwrap();
|
||||||
|
|
||||||
|
let piece_index = match piece {
|
||||||
|
Piece::Pawn => 0,
|
||||||
|
Piece::Knight => 1,
|
||||||
|
Piece::Bishop => 2,
|
||||||
|
Piece::Rook => 3,
|
||||||
|
Piece::Queen => 4,
|
||||||
|
Piece::King => 5,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Determine plane based on perspective
|
||||||
|
let plane = if color == us {
|
||||||
|
piece_index
|
||||||
|
} else {
|
||||||
|
piece_index + 6
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut rank = square.get_rank().to_index();
|
||||||
|
let mut file = square.get_file().to_index();
|
||||||
|
|
||||||
|
// Flip board if black to move
|
||||||
|
if flip {
|
||||||
|
rank = 7 - rank;
|
||||||
|
file = 7 - file;
|
||||||
|
}
|
||||||
|
|
||||||
|
planes[idx(plane, rank, file)] = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Castling Rights (planes 12–15)
|
||||||
|
// From perspective
|
||||||
|
// -------------------------
|
||||||
|
let (us_castle, them_castle) = (board.castle_rights(us), board.castle_rights(them));
|
||||||
|
|
||||||
|
if us_castle.has_kingside() {
|
||||||
|
fill_plane(&mut planes, 12);
|
||||||
|
}
|
||||||
|
if us_castle.has_queenside() {
|
||||||
|
fill_plane(&mut planes, 13);
|
||||||
|
}
|
||||||
|
if them_castle.has_kingside() {
|
||||||
|
fill_plane(&mut planes, 14);
|
||||||
|
}
|
||||||
|
if them_castle.has_queenside() {
|
||||||
|
fill_plane(&mut planes, 15);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Repetition plane (16)
|
||||||
|
// -------------------------
|
||||||
|
let current_hash = board.get_hash();
|
||||||
|
if let Some(count) = state.repetition_table.get(¤t_hash) {
|
||||||
|
if *count >= 2 {
|
||||||
|
fill_plane(&mut planes, 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let en_passant = board.en_passant();
|
||||||
|
if let Some(ep) = en_passant {
|
||||||
|
let mut rank = ep.get_rank().to_index();
|
||||||
|
let mut file = ep.get_file().to_index();
|
||||||
|
|
||||||
|
// Flip board if black to move
|
||||||
|
if flip {
|
||||||
|
rank = 7 - rank;
|
||||||
|
file = 7 - file;
|
||||||
|
}
|
||||||
|
|
||||||
|
planes[idx(17, rank, file)] = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::<B, 3>::from_floats(TensorData::new(planes, [18, 8, 8]), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets all 64 cells of plane number `plane` in the flat, plane-major `buffer` to 1.0.
///
/// # Panics
///
/// Panics if the plane does not lie entirely inside `buffer`; the bounds are checked up
/// front, so nothing is written in that case.
fn fill_plane(buffer: &mut [f32], plane: usize) {
    let start = plane * 64;
    buffer[start..start + 64].fill(1.0);
}
|
||||||
|
|
||||||
|
pub fn encode_move(mv: ChessMove, side_to_move: Color) -> usize {
|
||||||
|
let from = mv.get_source().to_index();
|
||||||
|
let to = mv.get_dest().to_index();
|
||||||
|
|
||||||
|
let mut from_rank = from / 8;
|
||||||
|
let mut from_file = from % 8;
|
||||||
|
let mut to_rank = to / 8;
|
||||||
|
let mut to_file = to % 8;
|
||||||
|
|
||||||
|
if side_to_move == Color::Black {
|
||||||
|
from_rank = 7 - from_rank;
|
||||||
|
from_file = 7 - from_file;
|
||||||
|
|
||||||
|
to_rank = 7 - to_rank;
|
||||||
|
to_file = 7 - to_file;
|
||||||
|
}
|
||||||
|
|
||||||
|
let delta_rank = to_rank as i32 - from_rank as i32;
|
||||||
|
let delta_file = to_file as i32 - from_file as i32;
|
||||||
|
|
||||||
|
let plane = encode_move_type(delta_rank, delta_file, mv.get_promotion());
|
||||||
|
|
||||||
|
plane * 64 + (from_rank * 8 + from_file)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_move_type(dr: i32, df: i32, promotion: Option<Piece>) -> usize {
|
||||||
|
// Knight moves
|
||||||
|
const KNIGHT_DELTAS: [(i32, i32); 8] = [
|
||||||
|
(2, 1),
|
||||||
|
(1, 2),
|
||||||
|
(-1, 2),
|
||||||
|
(-2, 1),
|
||||||
|
(-2, -1),
|
||||||
|
(-1, -2),
|
||||||
|
(1, -2),
|
||||||
|
(2, -1),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (i, (r, f)) in KNIGHT_DELTAS.iter().enumerate() {
|
||||||
|
if dr == *r && df == *f {
|
||||||
|
return 56 + i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UNDERPromotions
|
||||||
|
if let Some(promo) = promotion {
|
||||||
|
if promo != Piece::Queen {
|
||||||
|
let dir = if df == 0 {
|
||||||
|
0
|
||||||
|
} else if df < 0 {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
2
|
||||||
|
};
|
||||||
|
|
||||||
|
let piece_index = match promo {
|
||||||
|
// Piece::Queen => 0,
|
||||||
|
Piece::Rook => 0,
|
||||||
|
Piece::Bishop => 1,
|
||||||
|
Piece::Knight => 2,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
return 64 + dir * 3 + piece_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sliding
|
||||||
|
let direction_index = match (dr.signum(), df.signum()) {
|
||||||
|
(1, 0) => 0, // N
|
||||||
|
(1, 1) => 1,
|
||||||
|
(0, 1) => 2,
|
||||||
|
(-1, 1) => 3,
|
||||||
|
(-1, 0) => 4,
|
||||||
|
(-1, -1) => 5,
|
||||||
|
(0, -1) => 6,
|
||||||
|
(1, -1) => 7,
|
||||||
|
_ => panic!("Invalid move delta"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let distance = dr.abs().max(df.abs()) as usize - 1;
|
||||||
|
|
||||||
|
direction_index * 7 + distance
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_move(index: usize, side_to_move: Color) -> ChessMove {
|
||||||
|
let from_index = index % 64;
|
||||||
|
let plane = index / 64;
|
||||||
|
|
||||||
|
// Perspective-space coordinates
|
||||||
|
let mut from_rank = from_index / 8;
|
||||||
|
let mut from_file = from_index % 8;
|
||||||
|
|
||||||
|
let (mut dr, mut df, promotion) = decode_move_type(plane);
|
||||||
|
|
||||||
|
// Convert from perspective coordinates back to absolute board coordinates
|
||||||
|
if side_to_move == Color::Black {
|
||||||
|
from_rank = 7 - from_rank;
|
||||||
|
from_file = 7 - from_file;
|
||||||
|
|
||||||
|
dr = -dr;
|
||||||
|
df = -df;
|
||||||
|
}
|
||||||
|
|
||||||
|
let to_rank = (from_rank as i32 + dr) as usize;
|
||||||
|
let to_file = (from_file as i32 + df) as usize;
|
||||||
|
|
||||||
|
let from = Square::make_square(Rank::from_index(from_rank), File::from_index(from_file));
|
||||||
|
|
||||||
|
let to = Square::make_square(Rank::from_index(to_rank), File::from_index(to_file));
|
||||||
|
|
||||||
|
ChessMove::new(from, to, promotion)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_move_type(plane: usize) -> (i32, i32, Option<Piece>) {
|
||||||
|
// Knight moves
|
||||||
|
const KNIGHT_DELTAS: [(i32, i32); 8] = [
|
||||||
|
(2, 1),
|
||||||
|
(1, 2),
|
||||||
|
(-1, 2),
|
||||||
|
(-2, 1),
|
||||||
|
(-2, -1),
|
||||||
|
(-1, -2),
|
||||||
|
(1, -2),
|
||||||
|
(2, -1),
|
||||||
|
];
|
||||||
|
|
||||||
|
// 0–55: sliding moves
|
||||||
|
if plane < 56 {
|
||||||
|
let direction = plane / 7;
|
||||||
|
let distance = (plane % 7) + 1;
|
||||||
|
|
||||||
|
let (dr, df) = match direction {
|
||||||
|
0 => (1, 0),
|
||||||
|
1 => (1, 1),
|
||||||
|
2 => (0, 1),
|
||||||
|
3 => (-1, 1),
|
||||||
|
4 => (-1, 0),
|
||||||
|
5 => (-1, -1),
|
||||||
|
6 => (0, -1),
|
||||||
|
7 => (1, -1),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
return (dr * distance as i32, df * distance as i32, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 56–63: knight moves
|
||||||
|
if plane < 64 {
|
||||||
|
let (dr, df) = KNIGHT_DELTAS[plane - 56];
|
||||||
|
return (dr, df, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 64–72: underpromotions
|
||||||
|
let promo_plane = plane - 64;
|
||||||
|
|
||||||
|
let dir = promo_plane / 3;
|
||||||
|
let piece_index = promo_plane % 3;
|
||||||
|
|
||||||
|
let df = match dir {
|
||||||
|
0 => 0,
|
||||||
|
1 => -1,
|
||||||
|
2 => 1,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let dr = 1; // always forward (important: assumes white perspective)
|
||||||
|
|
||||||
|
let promotion = Some(match piece_index {
|
||||||
|
0 => Piece::Rook,
|
||||||
|
1 => Piece::Bishop,
|
||||||
|
2 => Piece::Knight,
|
||||||
|
_ => unreachable!(),
|
||||||
|
});
|
||||||
|
|
||||||
|
(dr, df, promotion)
|
||||||
|
}
|
||||||
338
engine/src/net/model.rs
Normal file
338
engine/src/net/model.rs
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
use crate::mcts::{BoardState, MctsResults};
|
||||||
|
use crate::net::encoding::{encode_board_state_perspective, encode_move};
|
||||||
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
|
use burn::nn::conv::Conv2dConfig;
|
||||||
|
use burn::nn::loss::{MseLoss, Reduction};
|
||||||
|
use burn::nn::{BatchNorm, BatchNormConfig, LinearConfig, PaddingConfig2d};
|
||||||
|
use burn::tensor::activation::log_softmax;
|
||||||
|
use burn::tensor::backend::AutodiffBackend;
|
||||||
|
use burn::tensor::Transaction;
|
||||||
|
use burn::train::{InferenceStep, ItemLazy, TrainOutput, TrainStep};
|
||||||
|
use burn::{
|
||||||
|
nn::{conv::Conv2d, Linear, Relu},
|
||||||
|
prelude::*,
|
||||||
|
};
|
||||||
|
use burn_ndarray::NdArray;
|
||||||
|
use chess::ChessMove;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
/*
|
||||||
|
Input planes:
|
||||||
|
1-6: your pieces (Pawn, Knight, Bishop, Rook, Queen, King)
|
||||||
|
7-12: opponent pieces
|
||||||
|
13: your kingside
|
||||||
|
14: your queenside
|
||||||
|
15: opponent kingside
|
||||||
|
16: opponent queenside
|
||||||
|
17: rep >= 2
|
||||||
|
18: En passant square (gap pass over)
|
||||||
|
|
||||||
|
|
||||||
|
Output:
|
||||||
|
4672 moves represented through planes
|
||||||
|
1-56: sliding moves
|
||||||
|
57-64: Knight moves
|
||||||
|
65-73: underpromotions
|
||||||
|
*/
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct ResidualBlock<B: Backend> {
|
||||||
|
conv1: Conv2d<B>,
|
||||||
|
bn1: BatchNorm<B>,
|
||||||
|
conv2: Conv2d<B>,
|
||||||
|
bn2: BatchNorm<B>,
|
||||||
|
activation: Relu,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
pub struct ResidualBlockConfig {
|
||||||
|
channels: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ResidualBlock<B> {
|
||||||
|
pub fn new(channels: usize, device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
conv1: Conv2dConfig::new([channels, channels], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
||||||
|
.init(device),
|
||||||
|
bn1: BatchNormConfig::new(channels).init(device),
|
||||||
|
conv2: Conv2dConfig::new([channels, channels], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
||||||
|
.init(device),
|
||||||
|
bn2: BatchNormConfig::new(channels).init(device),
|
||||||
|
activation: Relu::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
|
let residual = x.clone();
|
||||||
|
|
||||||
|
let x = self.conv1.forward(x);
|
||||||
|
let x = self.bn1.forward(x);
|
||||||
|
let x = self.activation.forward(x);
|
||||||
|
|
||||||
|
let x = self.conv2.forward(x);
|
||||||
|
let x = self.bn2.forward(x);
|
||||||
|
|
||||||
|
self.activation.forward(x + residual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct ChessModel<B: Backend> {
|
||||||
|
// Initial convolution
|
||||||
|
conv: Conv2d<B>,
|
||||||
|
bn: BatchNorm<B>,
|
||||||
|
|
||||||
|
// Residual tower
|
||||||
|
residual_blocks: Vec<ResidualBlock<B>>,
|
||||||
|
|
||||||
|
// Policy head
|
||||||
|
policy_conv: Conv2d<B>,
|
||||||
|
policy_bn: BatchNorm<B>,
|
||||||
|
policy_fc: Linear<B>,
|
||||||
|
|
||||||
|
// Value head
|
||||||
|
value_conv: Conv2d<B>,
|
||||||
|
value_bn: BatchNorm<B>,
|
||||||
|
value_fc1: Linear<B>,
|
||||||
|
value_fc2: Linear<B>,
|
||||||
|
|
||||||
|
activation: Relu,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
pub struct ChessModelConfig {
|
||||||
|
num_blocks: usize,
|
||||||
|
channels: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChessModelConfig {
|
||||||
|
pub fn init<B: Backend>(
|
||||||
|
num_blocks: usize,
|
||||||
|
channels: usize,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> ChessModel<B> {
|
||||||
|
ChessModel {
|
||||||
|
conv: Conv2dConfig::new([18, channels], [3, 3]) // 18 plane input
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
||||||
|
.init(device),
|
||||||
|
bn: BatchNormConfig::new(channels).init(device),
|
||||||
|
|
||||||
|
residual_blocks: (0..num_blocks)
|
||||||
|
.map(|_| ResidualBlock::new(channels, device))
|
||||||
|
.collect(),
|
||||||
|
|
||||||
|
// Policy head
|
||||||
|
policy_conv: Conv2dConfig::new([channels, 2], [1, 1]).init(device),
|
||||||
|
policy_bn: BatchNormConfig::new(2).init(device),
|
||||||
|
policy_fc: LinearConfig::new(2 * 8 * 8, 8 * 8 * 73).init(device), // 4672 typical chess move space
|
||||||
|
|
||||||
|
// Value head
|
||||||
|
value_conv: Conv2dConfig::new([channels, 1], [1, 1]).init(device),
|
||||||
|
value_bn: BatchNormConfig::new(1).init(device),
|
||||||
|
value_fc1: LinearConfig::new(1 * 8 * 8, 256).init(device),
|
||||||
|
value_fc2: LinearConfig::new(256, 1).init(device),
|
||||||
|
|
||||||
|
activation: Relu::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ChessModel<B> {
|
||||||
|
pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||||
|
let mut x = self.conv.forward(x);
|
||||||
|
x = self.bn.forward(x);
|
||||||
|
x = self.activation.forward(x);
|
||||||
|
|
||||||
|
for block in &self.residual_blocks {
|
||||||
|
x = block.forward(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
let batch_size = x.dims()[0];
|
||||||
|
|
||||||
|
// -------- Policy Head --------
|
||||||
|
let mut p = self.policy_conv.forward(x.clone());
|
||||||
|
p = self.policy_bn.forward(p);
|
||||||
|
p = self.activation.forward(p);
|
||||||
|
let mut p = p.reshape([batch_size, 2 * 8 * 8]);
|
||||||
|
p = self.policy_fc.forward(p);
|
||||||
|
|
||||||
|
// -------- Value Head --------
|
||||||
|
let mut v = self.value_conv.forward(x);
|
||||||
|
v = self.value_bn.forward(v);
|
||||||
|
v = self.activation.forward(v);
|
||||||
|
let mut v = v.reshape([batch_size, 8 * 8]);
|
||||||
|
v = self.activation.forward(self.value_fc1.forward(v));
|
||||||
|
v = self.value_fc2.forward(v).tanh();
|
||||||
|
|
||||||
|
(p, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward_chess(
|
||||||
|
&self,
|
||||||
|
boards: Tensor<B, 4>, // e.g. [batch, channels, 8, 8]
|
||||||
|
policy_targets: Tensor<B, 2>, // move distribution (make sure is normalized)
|
||||||
|
value_targets: Tensor<B, 2>, // scalar evaluation
|
||||||
|
) -> ChessOutput<B> {
|
||||||
|
let (policy_logits, value) = self.forward(boards);
|
||||||
|
|
||||||
|
let log_probs = log_softmax(policy_logits.clone(), 1);
|
||||||
|
let policy_loss = policy_targets
|
||||||
|
.clone()
|
||||||
|
.mul(log_probs)
|
||||||
|
.neg()
|
||||||
|
.sum_dim(1)
|
||||||
|
.mean();
|
||||||
|
|
||||||
|
let value_loss =
|
||||||
|
MseLoss::new().forward(value.clone(), value_targets.clone(), Reduction::Mean);
|
||||||
|
|
||||||
|
let total_loss = policy_loss + 0.5 * value_loss;
|
||||||
|
|
||||||
|
ChessOutput {
|
||||||
|
policy_logits,
|
||||||
|
value,
|
||||||
|
policy_targets,
|
||||||
|
value_targets,
|
||||||
|
loss: total_loss,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChessOutput<B: Backend> {
|
||||||
|
pub policy_logits: Tensor<B, 2>, // [sample, num_moves (4672)]
|
||||||
|
pub value: Tensor<B, 2>, // [sample, 1 value]
|
||||||
|
pub policy_targets: Tensor<B, 2>,
|
||||||
|
pub value_targets: Tensor<B, 2>, // kept 2d for consistency?
|
||||||
|
pub loss: Tensor<B, 1>, // [sample]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ItemLazy for ChessOutput<B> {
|
||||||
|
type ItemSync = ChessOutput<NdArray>;
|
||||||
|
|
||||||
|
fn sync(self) -> Self::ItemSync {
|
||||||
|
let [policy_logits, value, policy_targets, value_targets, loss] = Transaction::default()
|
||||||
|
.register(self.policy_logits)
|
||||||
|
.register(self.value)
|
||||||
|
.register(self.policy_targets)
|
||||||
|
.register(self.value_targets)
|
||||||
|
.register(self.loss)
|
||||||
|
.execute()
|
||||||
|
.try_into()
|
||||||
|
.expect("Correct amount of tensor data");
|
||||||
|
|
||||||
|
let device = &Default::default();
|
||||||
|
|
||||||
|
ChessOutput {
|
||||||
|
policy_logits: Tensor::from_data(policy_logits, device),
|
||||||
|
value: Tensor::from_data(value, device),
|
||||||
|
policy_targets: Tensor::from_data(policy_targets, device),
|
||||||
|
value_targets: Tensor::from_data(value_targets, device),
|
||||||
|
loss: Tensor::from_data(loss, device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: AutodiffBackend> TrainStep for ChessModel<B> {
|
||||||
|
type Input = ChessBatch<B>;
|
||||||
|
type Output = ChessOutput<B>;
|
||||||
|
|
||||||
|
fn step(&self, batch: ChessBatch<B>) -> TrainOutput<ChessOutput<B>> {
|
||||||
|
let item = self.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
|
||||||
|
|
||||||
|
TrainOutput::new(self, item.loss.backward(), item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> InferenceStep for ChessModel<B> {
|
||||||
|
type Input = ChessBatch<B>;
|
||||||
|
type Output = ChessOutput<B>;
|
||||||
|
|
||||||
|
fn step(&self, batch: ChessBatch<B>) -> ChessOutput<B> {
|
||||||
|
self.forward_chess(batch.states, batch.policy_targets, batch.value_targets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TrainingSample {
|
||||||
|
pub board_state: BoardState,
|
||||||
|
pub policy_target: HashMap<ChessMove, f32>,
|
||||||
|
pub value_target: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrainingSample {
|
||||||
|
pub fn new(
|
||||||
|
board_state: BoardState,
|
||||||
|
policy_target: HashMap<ChessMove, f32>,
|
||||||
|
value_target: f32,
|
||||||
|
) -> Self {
|
||||||
|
TrainingSample {
|
||||||
|
board_state,
|
||||||
|
policy_target,
|
||||||
|
value_target,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_mcts_with_outcome(mcts_results: MctsResults, outcome: f32) -> Self {
|
||||||
|
TrainingSample::new(mcts_results.board_state, mcts_results.move_dist, outcome)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stateless batcher that converts [`TrainingSample`]s into a [`ChessBatch`].
#[derive(Clone, Default)]
pub struct ChessBatcher {}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ChessBatch<B: Backend> {
|
||||||
|
pub states: Tensor<B, 4>,
|
||||||
|
pub policy_targets: Tensor<B, 2>,
|
||||||
|
pub value_targets: Tensor<B, 2>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
||||||
|
fn batch(&self, items: Vec<TrainingSample>, device: &B::Device) -> ChessBatch<B> {
|
||||||
|
let state_tensors = items
|
||||||
|
.iter()
|
||||||
|
.map(|item| {
|
||||||
|
encode_board_state_perspective(&item.board_state, device).reshape([1, 18, 8, 8])
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let policy_target_tensors = items
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.map(|item| {
|
||||||
|
let mut policy = vec![0.0f32; 4672];
|
||||||
|
let stm = item.board_state.board.side_to_move();
|
||||||
|
for (mv, prob) in item.policy_target {
|
||||||
|
policy[encode_move(mv, stm)] = prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
let sum: f32 = policy.iter().sum();
|
||||||
|
if sum > 0.0 {
|
||||||
|
for p in &mut policy {
|
||||||
|
*p /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let value_target_tensors = items
|
||||||
|
.iter()
|
||||||
|
.map(|item| {
|
||||||
|
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device)
|
||||||
|
.reshape([1, 1])
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let states = Tensor::cat(state_tensors, 0); // [B, 18, 8, 8]
|
||||||
|
let policy_targets = Tensor::cat(policy_target_tensors, 0); // [B, 4672]
|
||||||
|
let value_targets = Tensor::cat(value_target_tensors, 0); // [B, 1]
|
||||||
|
|
||||||
|
ChessBatch {
|
||||||
|
states,
|
||||||
|
policy_targets,
|
||||||
|
value_targets,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
59
engine/src/rl/classification.rs
Normal file
59
engine/src/rl/classification.rs
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/// Multi-label classification output adapted for multiple metrics.
|
||||||
|
///
|
||||||
|
/// Supported metrics:
|
||||||
|
/// - HammingScore
|
||||||
|
/// - Precision (via ConfusionStatsInput)
|
||||||
|
/// - Recall (via ConfusionStatsInput)
|
||||||
|
/// - FBetaScore (via ConfusionStatsInput)
|
||||||
|
/// - Loss
|
||||||
|
#[derive(new)]
|
||||||
|
pub struct MultiLabelSoftClassificationOutput<B: Backend> {
|
||||||
|
/// The loss.
|
||||||
|
pub loss: Tensor<B, 1>,
|
||||||
|
|
||||||
|
/// The label logits or probabilities. Shape: \[batch_size, num_classes\].
|
||||||
|
pub output: Tensor<B, 2>,
|
||||||
|
|
||||||
|
/// The ground truth labels (target values). Shape: \[batch_size, num_classes\].
|
||||||
|
pub targets: Tensor<B, 2>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ItemLazy for MultiLabelSoftClassificationOutput<B> {
|
||||||
|
type ItemSync = MultiLabelSoftClassificationOutput<NdArray>;
|
||||||
|
|
||||||
|
fn sync(self) -> Self::ItemSync {
|
||||||
|
let [output, loss, targets] = Transaction::default()
|
||||||
|
.register(self.output)
|
||||||
|
.register(self.loss)
|
||||||
|
.register(self.targets)
|
||||||
|
.execute()
|
||||||
|
.try_into()
|
||||||
|
.expect("Correct amount of tensor data");
|
||||||
|
|
||||||
|
let device = &Default::default();
|
||||||
|
|
||||||
|
MultiLabelSoftClassificationOutput {
|
||||||
|
output: Tensor::from_data(output, device),
|
||||||
|
loss: Tensor::from_data(loss, device),
|
||||||
|
targets: Tensor::from_data(targets, device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelSoftClassificationOutput<B> {
|
||||||
|
fn adapt(&self) -> HammingScoreInput<B> {
|
||||||
|
HammingScoreInput::new(self.output.clone(), self.targets.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelSoftClassificationOutput<B> {
|
||||||
|
fn adapt(&self) -> LossInput<B> {
|
||||||
|
LossInput::new(self.loss.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelSoftClassificationOutput<B> {
|
||||||
|
fn adapt(&self) -> ConfusionStatsInput<B> {
|
||||||
|
ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
|
||||||
|
}
|
||||||
|
}
|
||||||
134
engine/src/train.rs
Normal file
134
engine/src/train.rs
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
// use crate::model::{ChessBatcher, ChessModel, ChessModelConfig};
|
||||||
|
// use burn::config::Config;
|
||||||
|
// use burn::data::dataloader::batcher::Batcher;
|
||||||
|
// use burn::optim::{Adam, AdamConfig, Optimizer, SimpleOptimizer};
|
||||||
|
// use burn::optim::adaptor::OptimizerAdaptor;
|
||||||
|
// use burn::prelude::Backend;
|
||||||
|
// use burn::tensor::backend::AutodiffBackend;
|
||||||
|
// use burn::train::TrainStep;
|
||||||
|
// use crate::mcts::MctsResults;
|
||||||
|
//
|
||||||
|
// #[derive(Config, Debug)]
|
||||||
|
// pub struct ChessTrainerConfig {
|
||||||
|
// pub model: ChessModelConfig,
|
||||||
|
// pub optimizer: AdamConfig,
|
||||||
|
// #[config(default = 10)]
|
||||||
|
// pub num_epochs: usize,
|
||||||
|
// #[config(default = 64)]
|
||||||
|
// pub batch_size: usize,
|
||||||
|
// #[config(default = 4)]
|
||||||
|
// pub num_workers: usize,
|
||||||
|
// #[config(default = 42)]
|
||||||
|
// pub seed: u64,
|
||||||
|
// #[config(default = 1.0e-4)]
|
||||||
|
// pub learning_rate: f64,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// impl ChessTrainerConfig {
|
||||||
|
// pub fn init<B: Backend>(
|
||||||
|
// model_config: ChessModelConfig,
|
||||||
|
// optimizer: AdamConfig,
|
||||||
|
// num_epochs: usize,
|
||||||
|
// batch_size: usize,
|
||||||
|
// num_workers: usize,
|
||||||
|
// seed: u64,
|
||||||
|
// learning_rate: f64,
|
||||||
|
// ) -> ChessTrainer<B> {
|
||||||
|
// ChessTrainer {
|
||||||
|
// model: model_config::init(),
|
||||||
|
// optimizer: optimizer.init(),
|
||||||
|
// num_epochs,
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub struct ChessTrainer<B: AutodiffBackend> {
|
||||||
|
// pub model: ChessModel<B>,
|
||||||
|
// pub optimizer: Adam,
|
||||||
|
// learning_rate: f64,
|
||||||
|
// pub batcher: ChessBatcher,
|
||||||
|
// pub device: B::Device,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// impl<B: AutodiffBackend> ChessTrainer<B> {
|
||||||
|
// pub fn new(model: ChessModel<B>, device: B::Device) -> Self {
|
||||||
|
// let optimizer = AdamConfig::new().init();
|
||||||
|
//
|
||||||
|
// Self {
|
||||||
|
// model,
|
||||||
|
// optimizer,
|
||||||
|
// batcher: ChessBatcher::default(),
|
||||||
|
// device,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn train_step(&mut self, batch_data: Vec<MctsResults>) -> f32 {
|
||||||
|
// // 1. Convert to tensors
|
||||||
|
// let batch = self.batcher.batch(batch_data, &self.device);
|
||||||
|
//
|
||||||
|
// // 2. Forward + backward (your TrainStep impl)
|
||||||
|
// let output = self.model.step(batch);
|
||||||
|
//
|
||||||
|
// // 3. Update weights
|
||||||
|
// self.optimizer.step(, &mut self.model, output.grads);
|
||||||
|
//
|
||||||
|
// // 4. Return loss (for logging)
|
||||||
|
// let loss_tensor = output.item.loss.clone().into_data();
|
||||||
|
// let loss = loss_tensor.to_vec::<f32>().unwrap()[0];
|
||||||
|
//
|
||||||
|
// loss
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// fn create_artifact_dir(artifact_dir: &str) {
|
||||||
|
// // Remove existing artifacts before to get an accurate learner summary
|
||||||
|
// std::fs::remove_dir_all(artifact_dir).ok();
|
||||||
|
// std::fs::create_dir_all(artifact_dir).ok();
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn train<B: AutodiffBackend>(
|
||||||
|
// artifact_dir: &str,
|
||||||
|
// config: ChessTrainingConfig,
|
||||||
|
// device: B::Device,
|
||||||
|
// ) {
|
||||||
|
// create_artifact_dir(artifact_dir);
|
||||||
|
// config
|
||||||
|
// .save(format!("{artifact_dir}/config.json"))
|
||||||
|
// .expect("Config should be saved successfully");
|
||||||
|
//
|
||||||
|
// B::seed(&device, config.seed);
|
||||||
|
//
|
||||||
|
// let batcher = ChessBatcher::default();
|
||||||
|
|
||||||
|
// let dataloader_train = DataLoaderBuilder::new(batcher.clone())
|
||||||
|
// .batch_size(config.batch_size)
|
||||||
|
// .shuffle(config.seed)
|
||||||
|
// .num_workers(config.num_workers)
|
||||||
|
// .build(MnistDataset::train());
|
||||||
|
//
|
||||||
|
// let dataloader_test = DataLoaderBuilder::new(batcher)
|
||||||
|
// .batch_size(config.batch_size)
|
||||||
|
// .shuffle(config.seed)
|
||||||
|
// .num_workers(config.num_workers)
|
||||||
|
// .build(MnistDataset::test());
|
||||||
|
//
|
||||||
|
// let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)
|
||||||
|
// .metrics((AccuracyMetric::new(), LossMetric::new()))
|
||||||
|
// .with_file_checkpointer(CompactRecorder::new())
|
||||||
|
// .num_epochs(config.num_epochs)
|
||||||
|
// .summary();
|
||||||
|
//
|
||||||
|
// let model = config.model.init::<B>(&device);
|
||||||
|
// let result = training.launch(Learner::new(
|
||||||
|
// model,
|
||||||
|
// config.optimizer.init(),
|
||||||
|
// config.learning_rate,
|
||||||
|
// ));
|
||||||
|
//
|
||||||
|
// result
|
||||||
|
// .model
|
||||||
|
// .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||||
|
// .expect("Trained model should be saved successfully");
|
||||||
|
// }
|
||||||
1
engine/src/training.rs
Normal file
1
engine/src/training.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod train;
|
||||||
236
engine/src/training/train.rs
Normal file
236
engine/src/training/train.rs
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
use crate::mcts::{BoardState, BoardStateStatus, Mcts, MctsConfig, MctsResults};
|
||||||
|
use crate::net::model::{ChessBatcher, ChessModel, ChessModelConfig, TrainingSample};
|
||||||
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
|
use burn::module::{AutodiffModule, Module};
|
||||||
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
|
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
||||||
|
use burn::tensor::backend::AutodiffBackend;
|
||||||
|
use chess::ChessMove;
|
||||||
|
use rand::rngs::ThreadRng;
|
||||||
|
use rand::seq::SliceRandom;
|
||||||
|
use rand::RngExt;
|
||||||
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
pub struct TrainingConfig {
|
||||||
|
pub max_time_s: Option<u64>,
|
||||||
|
pub num_iters: Option<u32>,
|
||||||
|
pub max_depth: u16, // unused
|
||||||
|
pub model_name: String,
|
||||||
|
pub load_model: bool,
|
||||||
|
pub hidden_channels: usize,
|
||||||
|
pub num_blocks: usize,
|
||||||
|
pub batch_size: usize, // num positions sampled to a batch (2048)
|
||||||
|
pub num_episodes: usize, // num games generated per iteration (5000)
|
||||||
|
pub buffer_max_size: usize, // max number of samples in the buffer
|
||||||
|
pub mcts_config: MctsConfig,
|
||||||
|
pub optimizer: AdamConfig,
|
||||||
|
pub lr: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train<B: AutodiffBackend>(training_config: TrainingConfig, device: B::Device) {
|
||||||
|
let model_path = format!("artifacts/{}", training_config.model_name.as_str());
|
||||||
|
println!("Creating model...");
|
||||||
|
let mut model: ChessModel<B> = ChessModelConfig::init(
|
||||||
|
training_config.hidden_channels,
|
||||||
|
training_config.num_blocks,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
if training_config.load_model {
|
||||||
|
println!("Loading model {}...", model_path);
|
||||||
|
// Load model in full precision from MessagePack file
|
||||||
|
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
|
||||||
|
model = model
|
||||||
|
.load_file(&model_path, &recorder, &device)
|
||||||
|
.expect("Should be able to load the model weights from the provided file");
|
||||||
|
}
|
||||||
|
|
||||||
|
let train = Arc::new(AtomicBool::new(true));
|
||||||
|
let train_signal = Arc::clone(&train);
|
||||||
|
|
||||||
|
let mut iter: u32 = 0;
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
ctrlc::set_handler(move || {
|
||||||
|
train_signal.store(false, Ordering::Relaxed);
|
||||||
|
println!("Finishing batch before exiting...");
|
||||||
|
})
|
||||||
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
|
let mut replay_buffer: VecDeque<TrainingSample> = VecDeque::new();
|
||||||
|
|
||||||
|
let mut mcts: Mcts<B::InnerBackend> = Mcts {
|
||||||
|
config: training_config.mcts_config,
|
||||||
|
_marker: PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
|
||||||
|
println!("Starting training...");
|
||||||
|
while train.load(Ordering::Relaxed) {
|
||||||
|
let infer_model = model.valid();
|
||||||
|
// Gen samples
|
||||||
|
println!("Generating {} games...", training_config.num_episodes);
|
||||||
|
for episode in 0..training_config.num_episodes {
|
||||||
|
println!("Episode: {}", episode);
|
||||||
|
let mut board_state = BoardState::default();
|
||||||
|
let mut episode_buffer: Vec<MctsResults> = vec![];
|
||||||
|
|
||||||
|
while board_state.status == BoardStateStatus::Ongoing {
|
||||||
|
let results = mcts.search(&board_state, &infer_model, &device);
|
||||||
|
episode_buffer.push(results);
|
||||||
|
|
||||||
|
let temp = if board_state.halfmove_clock < 30 {
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let adjusted = apply_temperature(&episode_buffer.last().unwrap().move_dist, temp);
|
||||||
|
|
||||||
|
let mv = sample_move(&adjusted, &mut rng).unwrap();
|
||||||
|
println!("playing move: {}", mv);
|
||||||
|
board_state.apply_move(mv)
|
||||||
|
}
|
||||||
|
|
||||||
|
for result in episode_buffer.iter().enumerate() {
|
||||||
|
if board_state.status == BoardStateStatus::Stalemate
|
||||||
|
|| board_state.status == BoardStateStatus::Threefold
|
||||||
|
|| board_state.status == BoardStateStatus::FiftyMove
|
||||||
|
{
|
||||||
|
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
|
||||||
|
result.1.clone(),
|
||||||
|
0.0,
|
||||||
|
));
|
||||||
|
} else if board_state.status == BoardStateStatus::WhiteWinner {
|
||||||
|
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
|
||||||
|
result.1.clone(),
|
||||||
|
((result.0 % 2) as f32 * -2.0) + 1.0,
|
||||||
|
));
|
||||||
|
} else if board_state.status == BoardStateStatus::BlackWinner {
|
||||||
|
replay_buffer.push_back(TrainingSample::from_mcts_with_outcome(
|
||||||
|
result.1.clone(),
|
||||||
|
((result.0 % 2) as f32 * 2.0) - 1.0,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if replay_buffer.len() > training_config.buffer_max_size {
|
||||||
|
replay_buffer.pop_front();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// train
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Finished! Training on {} samples...",
|
||||||
|
training_config.batch_size
|
||||||
|
);
|
||||||
|
let mut indices: Vec<usize> = (0..replay_buffer.len()).collect();
|
||||||
|
indices.shuffle(&mut rng);
|
||||||
|
|
||||||
|
let samples = indices
|
||||||
|
.into_iter()
|
||||||
|
.take(training_config.batch_size)
|
||||||
|
.map(|i| replay_buffer[i].clone())
|
||||||
|
.collect::<Vec<TrainingSample>>();
|
||||||
|
|
||||||
|
let batcher = ChessBatcher {};
|
||||||
|
|
||||||
|
let batch = batcher.batch(samples, &device);
|
||||||
|
|
||||||
|
let mut optim = training_config.optimizer.init();
|
||||||
|
|
||||||
|
let output = model.forward_chess(batch.states, batch.policy_targets, batch.value_targets);
|
||||||
|
|
||||||
|
let grads = output.loss.backward();
|
||||||
|
let grads = GradientsParams::from_grads(grads, &model);
|
||||||
|
|
||||||
|
model = optim.step(training_config.lr, model, grads);
|
||||||
|
|
||||||
|
iter += 1;
|
||||||
|
|
||||||
|
if iter % 100 == 0 {
|
||||||
|
println!("Completed {} iterations", iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
if training_config.max_time_s.is_some()
|
||||||
|
&& start_time.elapsed().as_secs() > training_config.max_time_s.unwrap()
|
||||||
|
{
|
||||||
|
println!("Training stopping due to time limit...");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if training_config.num_iters.is_some() && iter >= training_config.num_iters.unwrap() {
|
||||||
|
println!(
|
||||||
|
"Training stopping due to iteration limit ({} iters completed)...",
|
||||||
|
iter
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Saving model...");
|
||||||
|
// Save model in MessagePack format with full precision
|
||||||
|
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
|
||||||
|
model
|
||||||
|
.save_file(&model_path, &recorder)
|
||||||
|
.expect("Should be able to save the model");
|
||||||
|
println!("Model saved in {:?}, exiting training.", model_path);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reshapes a visit distribution with a sampling temperature.
///
/// * `temperature == 0.0` – returns a one-entry map that gives probability
///   `1.0` to the most visited move
/// * otherwise – raises every weight to `1 / temperature` and normalises the
///   result to sum to 1; if the powered weights do not sum to a positive
///   value they are returned un-normalised
///
/// An empty input yields an empty map.
///
/// # Panics
///
/// With a zero temperature, panics if two weights cannot be ordered (NaN).
fn apply_temperature(
    visits: &HashMap<ChessMove, f32>,
    temperature: f32,
) -> HashMap<ChessMove, f32> {
    if visits.is_empty() {
        return HashMap::new();
    }

    // Greedy selection: all probability mass goes to the best move.
    if temperature == 0.0 {
        return visits
            .iter()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .map(|(&best, _)| HashMap::from([(best, 1.0)]))
            .unwrap();
    }

    let exponent = 1.0 / temperature;

    let powered: HashMap<ChessMove, f32> = visits
        .iter()
        .map(|(&mv, &weight)| (mv, weight.powf(exponent)))
        .collect();

    let total: f32 = powered.values().sum();

    if total <= 0.0 {
        // Nothing sensible to normalise by; hand back the raw powered weights.
        return powered;
    }

    powered
        .into_iter()
        .map(|(mv, weight)| (mv, weight / total))
        .collect()
}
|
||||||
|
|
||||||
|
/// Draws one move at random from `dist`, a map from move to probability.
///
/// A uniform value in `[0, 1)` is drawn from `rng` and the probabilities are
/// subtracted from it in map-iteration order; the move whose probability
/// pushes the remainder below zero is returned.
///
/// Returns `None` only when `dist` is empty. If the probabilities do not use
/// up the drawn value (floating-point drift, or a distribution that sums to
/// less than 1), the most probable move is returned instead.
fn sample_move(dist: &HashMap<ChessMove, f32>, rng: &mut ThreadRng) -> Option<ChessMove> {
    let mut r: f32 = rng.random_range(0.0..1.0);

    for (m, p) in dist {
        r -= p;
        // Strict comparison: with `<=`, a draw of exactly 0.0 would select
        // whichever entry happens to come first, even one with probability 0.
        if r < 0.0 {
            return Some(*m);
        }
    }

    // Fallback due to floating point drift: prefer the most probable move
    // over an arbitrary (possibly zero-probability) first key.
    dist.iter()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(m, _)| *m)
}
|
||||||
8
uci/Cargo.toml
Normal file
8
uci/Cargo.toml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[package]
|
||||||
|
name = "uci"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
engine = { path = "../engine" }
|
||||||
|
chess = "3.2.0"
|
||||||
216
uci/src/main.rs
Normal file
216
uci/src/main.rs
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
use chess::{ChessMove, Game};
|
||||||
|
use engine::{legal_action_mask, Engine};
|
||||||
|
use std::io;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
/// Engine name reported to the GUI in the `id name` line.
const ENGINE_NAME: &str = "Chess Dragon";
/// Author reported to the GUI in the `id author` line.
const ENGINE_AUTHOR: &str = "Drake Marino";
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
pub fn uci_loop() {
|
||||||
|
let stdin = io::stdin();
|
||||||
|
|
||||||
|
let mut game = Game::new();
|
||||||
|
let mut engine = Engine::default();
|
||||||
|
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut input = String::new();
|
||||||
|
if stdin.read_line(&mut input).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let input = input.trim();
|
||||||
|
if input.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let parts: Vec<&str> = input.split_whitespace().collect();
|
||||||
|
let command = parts[0];
|
||||||
|
match command {
|
||||||
|
"uci" => uci_uci(),
|
||||||
|
"id" => uci_id(),
|
||||||
|
"option" => uci_option(),
|
||||||
|
"setoption" => uci_setoption(input),
|
||||||
|
"ucinewgame" => uci_ucinewgame(&mut game),
|
||||||
|
"position" => uci_position(input, &mut game),
|
||||||
|
"go" => uci_go(input, &mut game, &mut engine),
|
||||||
|
_ => panic!("Invalid command!"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UCI Commands
|
||||||
|
|
||||||
|
/// Prints the engine's `id name` and `id author` identification lines to stdout.
fn uci_id() {
    println!("id name {ENGINE_NAME}");
    println!("id author {ENGINE_AUTHOR}");
}
|
||||||
|
|
||||||
|
/// Advertises the engine's configurable options.
///
/// No options are implemented yet, so this prints nothing.
fn uci_option() {}
|
||||||
|
|
||||||
|
/// Responds to the `uci` handshake: identification lines, then the option
/// list, then the closing `uciok`.
fn uci_uci() {
    // Order matters: `uciok` must be the last line of the reply.
    uci_id();
    uci_option();
    println!("uciok");
}
|
||||||
|
|
||||||
|
/// Handles the `setoption` command.
///
/// Not implemented yet: the full command line in `input` is accepted and
/// ignored.
fn uci_setoption(input: &str) {
    // TODO: parse `input` once the engine exposes options.
    let _ = input;
}
|
||||||
|
|
||||||
|
/// Handles `ucinewgame` by replacing `game` with a freshly constructed
/// [`Game`].
fn uci_ucinewgame(game: &mut Game) {
    *game = Game::new();
}
|
||||||
|
|
||||||
|
fn uci_position(input: &str, game: &mut Game) {
|
||||||
|
let mut tokens = input.split_whitespace();
|
||||||
|
|
||||||
|
if tokens.next().unwrap() != "position" {
|
||||||
|
panic!("position command not provided!");
|
||||||
|
}
|
||||||
|
|
||||||
|
let moves: Vec<String>;
|
||||||
|
|
||||||
|
match tokens.next().unwrap() {
|
||||||
|
"startpos" => {
|
||||||
|
*game = Game::new();
|
||||||
|
moves = tokens.skip_while(|&t| t != "moves").skip(1).map(|s| s.to_string()).collect();
|
||||||
|
}
|
||||||
|
"fen" => {
|
||||||
|
// FEN has 6 space-separated fields
|
||||||
|
let fen_fields: Vec<&str> = tokens.by_ref().take(6).collect();
|
||||||
|
if fen_fields.len() != 6 {
|
||||||
|
panic!("fen field invalid!");
|
||||||
|
}
|
||||||
|
let fen = fen_fields.join(" ");
|
||||||
|
*game = Game::from_str(&fen).expect("Invalid board position");
|
||||||
|
moves = tokens.skip_while(|&t| t != "moves").skip(1).map(|s| s.to_string()).collect();
|
||||||
|
}
|
||||||
|
_ => panic!("Position command invalid!"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for mv in moves {
|
||||||
|
game.make_move(ChessMove::from_str(&mv).expect("Invalid move!"));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uci_go(input: &str, game: &mut Game, engine: &mut Engine) {
|
||||||
|
let parts: Vec<&str> = input.split_whitespace().collect();
|
||||||
|
|
||||||
|
let mut wtime = None;
|
||||||
|
let mut btime = None;
|
||||||
|
let mut winc = None;
|
||||||
|
let mut binc = None;
|
||||||
|
let mut movetime = None;
|
||||||
|
let mut max_depth = None;
|
||||||
|
let mut max_nodes = None;
|
||||||
|
|
||||||
|
let mut i = 1; // Skip "go"
|
||||||
|
while i < parts.len() {
|
||||||
|
match parts[i] {
|
||||||
|
"wtime" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
wtime = parts[i + 1].parse::<u64>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"btime" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
btime = parts[i + 1].parse::<u64>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"winc" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
winc = parts[i + 1].parse::<u64>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"binc" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
binc = parts[i + 1].parse::<u64>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"movetime" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
movetime = parts[i + 1].parse::<u64>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"depth" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
max_depth = parts[i + 1].parse::<u16>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"infinite" => {
|
||||||
|
max_depth = Some(100);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
"nodes" => {
|
||||||
|
if i + 1 < parts.len() {
|
||||||
|
max_nodes = parts[i + 1].parse::<usize>().ok();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update search settings
|
||||||
|
if let Some(wt) = wtime {
|
||||||
|
engine.search_settings.wtime = wt;
|
||||||
|
}
|
||||||
|
if let Some(bt) = btime {
|
||||||
|
engine.search_settings.btime = bt;
|
||||||
|
}
|
||||||
|
if let Some(wi) = winc {
|
||||||
|
engine.search_settings.winc = wi;
|
||||||
|
}
|
||||||
|
if let Some(bi) = binc {
|
||||||
|
engine.search_settings.binc = bi;
|
||||||
|
}
|
||||||
|
if let Some(mt) = movetime {
|
||||||
|
engine.search_settings.movetime = Some(mt);
|
||||||
|
}
|
||||||
|
if let Some(max_depth) = max_depth {
|
||||||
|
engine.search_settings.max_depth = Some(max_depth);
|
||||||
|
}
|
||||||
|
if let Some(nodes) = max_nodes {
|
||||||
|
engine.search_settings.max_nodes = Some(nodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
uci_loop();
|
||||||
|
}
|
||||||
6
web/Cargo.toml
Normal file
6
web/Cargo.toml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
[package]
|
||||||
|
name = "web"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
3
web/src/main.rs
Normal file
3
web/src/main.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
/// Placeholder entry point for the `web` crate; prints a greeting and exits.
fn main() {
    println!("{}", "Hello, world!");
}
|
||||||
Loading…
Reference in New Issue
Block a user