updates
This commit is contained in:
parent
764252b0c4
commit
7d5c0cf7eb
@ -1,6 +1,6 @@
|
|||||||
#![recursion_limit = "256"]
|
#![recursion_limit = "256"]
|
||||||
|
|
||||||
use burn::backend::{Autodiff, Cuda};
|
use burn::backend::{Autodiff, Wgpu};
|
||||||
use burn::optim::AdamConfig;
|
use burn::optim::AdamConfig;
|
||||||
use engine::mcts::MctsConfig;
|
use engine::mcts::MctsConfig;
|
||||||
use engine::training::train::{train, TrainingConfig};
|
use engine::training::train::{train, TrainingConfig};
|
||||||
@ -14,15 +14,15 @@ use engine::training::train::{train, TrainingConfig};
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// type MyBackend = Wgpu<f32, i32>;
|
type MyBackend = Wgpu<f32, i32>;
|
||||||
type MyBackend = Cuda<f32, i32>;
|
// type MyBackend = Cuda<f32, i32>;
|
||||||
// type MyBackend = NdArray<f32, i32>;
|
// type MyBackend = NdArray<f32, i32>;
|
||||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||||
// let device = burn::backend::wgpu::WgpuDevice::default();
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
// let device = burn::backend::ndarray::NdArrayDevice::default();
|
||||||
let device = burn::backend::cuda::CudaDevice::default();
|
// let device = burn::backend::cuda::CudaDevice::default();
|
||||||
|
|
||||||
let mcts_config = MctsConfig::new(10, 1.0, 0.05, 0.25);
|
let mcts_config = MctsConfig::new(100, 1.0, 0.05, 0.25);
|
||||||
|
|
||||||
let adam_config = AdamConfig::new();
|
let adam_config = AdamConfig::new();
|
||||||
|
|
||||||
@ -34,8 +34,8 @@ fn main() {
|
|||||||
load_model: false,
|
load_model: false,
|
||||||
hidden_channels: 64,
|
hidden_channels: 64,
|
||||||
num_blocks: 4,
|
num_blocks: 4,
|
||||||
batch_size: 128, // num positions sampled to a batch (2048)
|
batch_size: 1024, // num positions sampled to a batch (2048)
|
||||||
num_episodes: 10, // num games generated per iteration (5000)
|
num_episodes: 100, // num games generated per iteration (5000)
|
||||||
buffer_max_size: 200_000, // max number of samples in the buffer
|
buffer_max_size: 200_000, // max number of samples in the buffer
|
||||||
mcts_config,
|
mcts_config,
|
||||||
optimizer: adam_config,
|
optimizer: adam_config,
|
||||||
|
|||||||
@ -10,7 +10,6 @@ use chess::Piece::{Bishop, Knight, Pawn, Queen, Rook};
|
|||||||
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
use chess::{Board, ChessMove, Color, MoveGen, Piece, ALL_COLORS, ALL_PIECES};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
pub struct Node {
|
pub struct Node {
|
||||||
pub prior: f32,
|
pub prior: f32,
|
||||||
@ -168,9 +167,9 @@ impl<B: Backend> Mcts<B> {
|
|||||||
let state: Tensor<B, 4> =
|
let state: Tensor<B, 4> =
|
||||||
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
encode_board_state_perspective(&arena[node_idx].board_state, device)
|
||||||
.reshape([1, 18, 8, 8]);
|
.reshape([1, 18, 8, 8]);
|
||||||
let start = Instant::now();
|
// let start = Instant::now();
|
||||||
let (policy_head, value_head) = model.forward(state);
|
let (policy_head, value_head) = model.forward(state);
|
||||||
println!("time: {:?}", start.elapsed());
|
// println!("time: {:?}", start.elapsed());
|
||||||
|
|
||||||
let legal_moves: Vec<ChessMove> =
|
let legal_moves: Vec<ChessMove> =
|
||||||
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();
|
MoveGen::new_legal(&arena[node_idx].board_state.board).collect();
|
||||||
|
|||||||
@ -313,14 +313,17 @@ impl<B: Backend> Batcher<B, TrainingSample, ChessBatch<B>> for ChessBatcher {
|
|||||||
*p /= sum;
|
*p /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Tensor::<B, 2>::from_floats(TensorData::new(policy, [4672]), device).unsqueeze()
|
Tensor::<B, 2>::from_floats(TensorData::new(policy, [1, 4672]), device).unsqueeze()
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let value_target_tensors = items
|
let value_target_tensors = items
|
||||||
.iter()
|
.iter()
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
Tensor::<B, 2>::from_floats(TensorData::new(vec![item.value_target], [1]), device)
|
Tensor::<B, 2>::from_floats(
|
||||||
|
TensorData::new(vec![item.value_target], [1, 1]),
|
||||||
|
device,
|
||||||
|
)
|
||||||
.reshape([1, 1])
|
.reshape([1, 1])
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user