This commit is contained in:
Drake Marino 2026-05-23 23:12:27 -05:00
parent 764252b0c4
commit 7d5c0cf7eb
3 changed files with 16 additions and 14 deletions

View File

@ -1,6 +1,6 @@
#![recursion_limit = "256"] #![recursion_limit = "256"]
use burn::backend::{Autodiff, Cuda}; use burn::backend::{Autodiff, Wgpu};
use burn::optim::AdamConfig; use burn::optim::AdamConfig;
use engine::mcts::MctsConfig; use engine::mcts::MctsConfig;
use engine::training::train::{train, TrainingConfig}; use engine::training::train::{train, TrainingConfig};
@ -14,15 +14,15 @@ use engine::training::train::{train, TrainingConfig};
// } // }
fn main() { fn main() {
// type MyBackend = Wgpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
type MyBackend = Cuda<f32, i32>; // type MyBackend = Cuda<f32, i32>;
// type MyBackend = NdArray<f32, i32>; // type MyBackend = NdArray<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
// let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();
// let device = burn::backend::ndarray::NdArrayDevice::default(); // let device = burn::backend::ndarray::NdArrayDevice::default();
let device = burn::backend::cuda::CudaDevice::default(); // let device = burn::backend::cuda::CudaDevice::default();
let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25); let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
let adam_config = AdamConfig::new(); let adam_config = AdamConfig::new();
@ -34,8 +34,8 @@ fn main() {
load_model: false, load_model: false,
hidden_channels: 64, hidden_channels: 64,
num_blocks: 4, num_blocks: 4,
batch_size: 128, // num positions sampled to a batch (2048) batch_size: 1024, // num positions sampled to a batch (2048)
num_episodes: 10, // num games generated per iteration (5000) num_episodes: 100, // num games generated per iteration (5000)
buffer_max_size: 200_000, // max number of samples in the buffer buffer_max_size: 200_000, // max number of samples in the buffer
mcts_config, mcts_config,
optimizer: adam_config, optimizer: adam_config,

View File

@ -10,7 +10,6 @@ use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES}; use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
use std::collections::HashMap; use std::collections::HashMap;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::time::Instant;
pub struct Node { pub struct Node {
pub prior: f32, pub prior: f32,
@ -168,9 +167,9 @@ impl<B: Backend> Mcts<B> {
let state: Tensor<B, 4> = let state: Tensor<B, 4> =
encode_board_state_perspective(&arena[node_idx].board_state, device) encode_board_state_perspective(&arena[node_idx].board_state, device)
.reshape([1, 18, 8, 8]); .reshape([1, 18, 8, 8]);
let start = Instant::now(); // let start = Instant::now();
let (policy_head, value_head) = model.forward(state); let (policy_head, value_head) = model.forward(state);
println!("time: {:?}", start.elapsed()); // println!("time: {:?}", start.elapsed());
let legal_moves: Vec<ChessMove> = let legal_moves: Vec<ChessMove> =
MoveGen::new_legal(&arena[node_idx].board_state.board).collect(); MoveGen::new_legal(&arena[node_idx].board_state.board).collect();

View File

@ -313,14 +313,17 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
*p /= sum; *p /= sum;
} }
} }
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze() Tensor::<B, 2>::from_floats(TensorData::new(policy, [1, 4672]), device).unsqueeze()
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let value_target_tensors = items let value_target_tensors = items
.iter() .iter()
.map(|item| { .map(|item| {
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device) Tensor::<B, 2>::from_floats(
TensorData::new(vec![item.value_target], [1, 1]),
device,
)
.reshape([1, 1]) .reshape([1, 1])
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();