This commit is contained in:
Drake Marino 2026-05-23 23:12:27 -05:00
parent 764252b0c4
commit 7d5c0cf7eb
3 changed files with 16 additions and 14 deletions

View File

@ -1,6 +1,6 @@
#![recursion_limit = "256"]
use burn::backend::{Autodiff, Cuda};
use burn::backend::{Autodiff, Wgpu};
use burn::optim::AdamConfig;
use engine::mcts::MctsConfig;
use engine::training::train::{train, TrainingConfig};
@ -14,15 +14,15 @@ use engine::training::train::{train, TrainingConfig};
// }
fn main() {
// type MyBackend = Wgpu<f32, i32>;
type MyBackend = Cuda<f32, i32>;
type MyBackend = Wgpu<f32, i32>;
// type MyBackend = Cuda<f32, i32>;
// type MyBackend = NdArray<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
// let device = burn::backend::wgpu::WgpuDevice::default();
let device = burn::backend::wgpu::WgpuDevice::default();
// let device = burn::backend::ndarray::NdArrayDevice::default();
let device = burn::backend::cuda::CudaDevice::default();
// let device = burn::backend::cuda::CudaDevice::default();
let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25);
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
let adam_config = AdamConfig::new();
@ -34,8 +34,8 @@ fn main() {
load_model: false,
hidden_channels: 64,
num_blocks: 4,
batch_size: 128, // num positions sampled to a batch (2048)
num_episodes: 10, // num games generated per iteration (5000)
batch_size: 1024, // num positions sampled to a batch (2048)
num_episodes: 100, // num games generated per iteration (5000)
buffer_max_size: 200_000, // max number of samples in the buffer
mcts_config,
optimizer: adam_config,

View File

@ -10,7 +10,6 @@ use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::time::Instant;
pub struct Node {
pub prior: f32,
@ -168,9 +167,9 @@ impl<B: Backend> Mcts<B> {
let state: Tensor<B, 4> =
encode_board_state_perspective(&arena[node_idx].board_state, device)
.reshape([1, 18, 8, 8]);
let start = Instant::now();
// let start = Instant::now();
let (policy_head, value_head) = model.forward(state);
println!("time: {:?}", start.elapsed());
// println!("time: {:?}", start.elapsed());
let legal_moves: Vec<ChessMove> =
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();

View File

@ -313,15 +313,18 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
*p /= sum;
}
}
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze()
Tensor::<B, 2>::from_floats(TensorData::new(policy, [1, 4672]), device).unsqueeze()
})
.collect::<Vec<_>>();
let value_target_tensors = items
.iter()
.map(|item| {
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device)
.reshape([1, 1])
Tensor::<B, 2>::from_floats(
TensorData::new(vec![item.value_target], [1, 1]),
device,
)
.reshape([1, 1])
})
.collect::<Vec<_>>();