updates
This commit is contained in:
parent
764252b0c4
commit
7d5c0cf7eb
@ -1,6 +1,6 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
use burn::backend::{Autodiff, Cuda};
|
||||
use burn::backend::{Autodiff, Wgpu};
|
||||
use burn::optim::AdamConfig;
|
||||
use engine::mcts::MctsConfig;
|
||||
use engine::training::train::{train, TrainingConfig};
|
||||
@ -14,15 +14,15 @@ use engine::training::train::{train, TrainingConfig};
|
||||
// }
|
||||
|
||||
fn main() {
|
||||
// type MyBackend = Wgpu<f32, i32>;
|
||||
type MyBackend = Cuda<f32, i32>;
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
// type MyBackend = Cuda<f32, i32>;
|
||||
// type MyBackend = NdArray<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
// let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
||||
let device = burn::backend::cuda::CudaDevice::default();
|
||||
// let device = burn::backend::cuda::CudaDevice::default();
|
||||
|
||||
let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25);
|
||||
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
|
||||
|
||||
let adam_config = AdamConfig::new();
|
||||
|
||||
@ -34,8 +34,8 @@ fn main() {
|
||||
load_model: false,
|
||||
hidden_channels: 64,
|
||||
num_blocks: 4,
|
||||
batch_size: 128, // num positions sampled to a batch (2048)
|
||||
num_episodes: 10, // num games generated per iteration (5000)
|
||||
batch_size: 1024, // num positions sampled to a batch (2048)
|
||||
num_episodes: 100, // num games generated per iteration (5000)
|
||||
buffer_max_size: 200_000, // max number of samples in the buffer
|
||||
mcts_config,
|
||||
optimizer: adam_config,
|
||||
|
||||
@ -10,7 +10,6 @@ use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
|
||||
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
use std::time::Instant;
|
||||
|
||||
pub struct Node {
|
||||
pub prior: f32,
|
||||
@ -168,9 +167,9 @@ impl<B: Backend> Mcts<B> {
|
||||
let state: Tensor<B, 4> =
|
||||
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
||||
.reshape([1, 18, 8, 8]);
|
||||
let start = Instant::now();
|
||||
// let start = Instant::now();
|
||||
let (policy_head, value_head) = model.forward(state);
|
||||
println!("time: {:?}", start.elapsed());
|
||||
// println!("time: {:?}", start.elapsed());
|
||||
|
||||
let legal_moves: Vec<ChessMove> =
|
||||
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();
|
||||
|
||||
@ -313,14 +313,17 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
||||
*p /= sum;
|
||||
}
|
||||
}
|
||||
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze()
|
||||
Tensor::<B, 2>::from_floats(TensorData::new(policy, [1, 4672]), device).unsqueeze()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let value_target_tensors = items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device)
|
||||
Tensor::<B, 2>::from_floats(
|
||||
TensorData::new(vec![item.value_target], [1, 1]),
|
||||
device,
|
||||
)
|
||||
.reshape([1, 1])
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user