#![recursion_limit = "256"] use burn::backend::{Autodiff, Wgpu}; use burn::optim::AdamConfig; use engine::mcts::MctsConfig; use engine::training::train::{train, TrainingConfig}; // fn main() { // type MyBackend = Wgpu; // // let device = Default::default(); // let model = ChessModelConfig::new(10, 512).init::(&device); // // println!("{model}"); // } fn main() { type MyBackend = Wgpu; // type MyBackend = Cuda; // type MyBackend = NdArray; type MyAutodiffBackend = Autodiff; let device = burn::backend::wgpu::WgpuDevice::default(); // let device = burn::backend::ndarray::NdArrayDevice::default(); // let device = burn::backend::cuda::CudaDevice::default(); let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25); let adam_config = AdamConfig::new(); let training_config = TrainingConfig { max_time_s: None, num_iters: None, max_depth: 100, // unused model_name: String::from("Test1"), load_model: false, hidden_channels: 64, num_blocks: 4, batch_size: 1024, // num positions sampled to a batch (2048) num_episodes: 100, // num games generated per iteration (5000) buffer_max_size: 200_000, // max number of samples in the buffer mcts_config, optimizer: adam_config, lr: 2e-4, }; train::(training_config, device); // let mut board_state: Option = Some(BoardState::new( // Board::default(), // 0, // HashMap::::new(), // )); // // // // let start = Instant::now(); // // rollout(&board, &PolicyNetwork::new()); // // let end = Instant::now(); // // println!("Rollout: {:?}", end - start); // let mut previous_state = board_state.clone().unwrap(); // // let mut move_history: String = String::new(); // // let mut mcts = MctsConfig::new() // // while board_state.is_some() { // // mcts now returns Option<(BoardState, String)> // let results = Mcts::search(&previous_state, 1000000, 0); // if !results.move_dist.is_empty() { // let mut best_prob: f32 = -1.0; // let mut best_move: &ChessMove = &ChessMove::default(); // // for m in results.move_dist.iter() { // if m.1 > &best_prob { // best_move = m.0; // } // } // // previous_state = results.board_state; // previous_state.board = previous_state.board.make_move_new(*best_move); // board_state = Some(previous_state.clone()); // // println!( // "Chosen move: {}, board state: {}", // best_move.to_string(), // previous_state.board // ); // move_history += &*best_move.to_string(); // move_history += " "; // } else { // board_state = None; // } // } // // println!("Finished game w/ board state: {}", previous_state.board); // println!("Move history: {}", move_history); // println!( // "Game ended! status: {:?}, clock: {}, 3-fold: {}", // previous_state.board.status(), // previous_state.halfmove_clock, // previous_state // .repetition_table // .iter() // .any(|(_, ct)| ct >= &3u8) // ); }