105 lines
3.4 KiB
Rust
105 lines
3.4 KiB
Rust
#![recursion_limit = "256"]
|
|
|
|
use burn::backend::{Autodiff, Wgpu};
|
|
use burn::optim::AdamConfig;
|
|
use engine::mcts::MctsConfig;
|
|
use engine::training::train::{train, TrainingConfig};
|
|
// fn main() {
|
|
// type MyBackend = Wgpu<f32, i32>;
|
|
//
|
|
// let device = Default::default();
|
|
// let model = ChessModelConfig::new(10, 512).init::<MyBackend>(&device);
|
|
//
|
|
// println!("{model}");
|
|
// }
|
|
|
|
fn main() {
|
|
type MyBackend = Wgpu<f32, i32>;
|
|
// type MyBackend = Cuda<f32, i32>;
|
|
// type MyBackend = NdArray<f32, i32>;
|
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
|
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
|
// let device = burn::backend::cuda::CudaDevice::default();
|
|
|
|
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
|
|
|
|
let adam_config = AdamConfig::new();
|
|
|
|
let training_config = TrainingConfig {
|
|
max_time_s: None,
|
|
num_iters: None,
|
|
max_depth: 100, // unused
|
|
model_name: String::from("Test1"),
|
|
load_model: false,
|
|
hidden_channels: 64,
|
|
num_blocks: 4,
|
|
batch_size: 1024, // num positions sampled to a batch (2048)
|
|
num_episodes: 100, // num games generated per iteration (5000)
|
|
buffer_max_size: 200_000, // max number of samples in the buffer
|
|
mcts_config,
|
|
optimizer: adam_config,
|
|
lr: 2e-4,
|
|
};
|
|
|
|
train::<MyAutodiffBackend>(training_config, device);
|
|
|
|
// let mut board_state: Option<BoardState> = Some(BoardState::new(
|
|
// Board::default(),
|
|
// 0,
|
|
// HashMap::<u64, u8>::new(),
|
|
// ));
|
|
//
|
|
//
|
|
// // let start = Instant::now();
|
|
// // rollout(&board, &PolicyNetwork::new());
|
|
// // let end = Instant::now();
|
|
// // println!("Rollout: {:?}", end - start);
|
|
// let mut previous_state = board_state.clone().unwrap();
|
|
//
|
|
// let mut move_history: String = String::new();
|
|
//
|
|
// let mut mcts = MctsConfig::new()
|
|
//
|
|
// while board_state.is_some() {
|
|
// // mcts now returns Option<(BoardState, String)>
|
|
// let results = Mcts::search(&previous_state, 1000000, 0);
|
|
// if !results.move_dist.is_empty() {
|
|
// let mut best_prob: f32 = -1.0;
|
|
// let mut best_move: &ChessMove = &ChessMove::default();
|
|
//
|
|
// for m in results.move_dist.iter() {
|
|
// if m.1 > &best_prob {
|
|
// best_move = m.0;
|
|
// }
|
|
// }
|
|
//
|
|
// previous_state = results.board_state;
|
|
// previous_state.board = previous_state.board.make_move_new(*best_move);
|
|
// board_state = Some(previous_state.clone());
|
|
//
|
|
// println!(
|
|
// "Chosen move: {}, board state: {}",
|
|
// best_move.to_string(),
|
|
// previous_state.board
|
|
// );
|
|
// move_history += &*best_move.to_string();
|
|
// move_history += " ";
|
|
// } else {
|
|
// board_state = None;
|
|
// }
|
|
// }
|
|
//
|
|
// println!("Finished game w/ board state: {}", previous_state.board);
|
|
// println!("Move history: {}", move_history);
|
|
// println!(
|
|
// "Game ended! status: {:?}, clock: {}, 3-fold: {}",
|
|
// previous_state.board.status(),
|
|
// previous_state.halfmove_clock,
|
|
// previous_state
|
|
// .repetition_table
|
|
// .iter()
|
|
// .any(|(_, ct)| ct >= &3u8)
|
|
// );
|
|
}
|