chess_dragon/engine/src/main.rs
2026-05-23 23:12:27 -05:00

105 lines
3.4 KiB
Rust

#![recursion_limit = "256"]
use burn::backend::{Autodiff, Wgpu};
use burn::optim::AdamConfig;
use engine::mcts::MctsConfig;
use engine::training::train::{train, TrainingConfig};
// fn main() {
// type MyBackend = Wgpu<f32, i32>;
//
// let device = Default::default();
// let model = ChessModelConfig::new(10, 512).init::<MyBackend>(&device);
//
// println!("{model}");
// }
/// Entry point: builds the MCTS, optimizer and training configuration and
/// launches training of the chess model via `engine::training::train::train`.
///
/// The backend is chosen by editing the `MyBackend` alias together with the
/// matching `device` line below; the Cuda and NdArray alternatives are kept as
/// comments so switching is a two-line change.
fn main() {
    type MyBackend = Wgpu<f32, i32>;
    // type MyBackend = Cuda<f32, i32>;
    // type MyBackend = NdArray<f32, i32>;
    // `train` is instantiated with the `Autodiff`-wrapped backend (burn's
    // autodiff decorator, which provides gradient tracking for training).
    type MyAutodiffBackend = Autodiff<MyBackend>;

    // Must match the backend selected in `MyBackend` above.
    let device = burn::backend::wgpu::WgpuDevice::default();
    // let device = burn::backend::ndarray::NdArrayDevice::default();
    // let device = burn::backend::cuda::CudaDevice::default();

    // NOTE(review): positional arguments; their meaning is defined by
    // `engine::mcts::MctsConfig::new` (presumably a search budget followed by
    // exploration/noise parameters) — confirm there before tuning.
    let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
    let adam_config = AdamConfig::new();

    let training_config = TrainingConfig {
        // NOTE(review): presumably `None` means "no time / iteration limit" —
        // confirm against `TrainingConfig`.
        max_time_s: None,
        num_iters: None,
        max_depth: 100, // unused
        model_name: String::from("Test1"),
        load_model: false,
        hidden_channels: 64,
        num_blocks: 4,
        batch_size: 1024, // positions sampled per training batch (2048 noted as an alternative setting)
        num_episodes: 100, // games generated per iteration (5000 noted as an alternative setting)
        buffer_max_size: 200_000, // max number of samples kept in the replay buffer
        mcts_config,
        optimizer: adam_config,
        lr: 2e-4,
    };

    train::<MyAutodiffBackend>(training_config, device);

    // An older, commented-out interactive self-play loop used to follow here.
    // It no longer matched the current `MctsConfig` / `Mcts::search` API and
    // selected the last move rather than the highest-probability one (its
    // `best_prob` was never updated), so it was removed; recover it from
    // version control history if needed.
}