135 lines
4.0 KiB
XML
135 lines
4.0 KiB
XML
// use crate::model::{ChessBatcher, ChessModel, ChessModelConfig};
|
|
// use burn::config::Config;
|
|
// use burn::data::dataloader::batcher::Batcher;
|
|
// use burn::optim::{Adam, AdamConfig, Optimizer, SimpleOptimizer};
|
|
// use burn::optim::adaptor::OptimizerAdaptor;
|
|
// use burn::prelude::Backend;
|
|
// use burn::tensor::backend::AutodiffBackend;
|
|
// use burn::train::TrainStep;
|
|
// use crate::mcts::MctsResults;
|
|
//
|
|
// #[derive(Config, Debug)]
|
|
// pub struct ChessTrainerConfig {
|
|
// pub model: ChessModelConfig,
|
|
// pub optimizer: AdamConfig,
|
|
// #[config(default = 10)]
|
|
// pub num_epochs: usize,
|
|
// #[config(default = 64)]
|
|
// pub batch_size: usize,
|
|
// #[config(default = 4)]
|
|
// pub num_workers: usize,
|
|
// #[config(default = 42)]
|
|
// pub seed: u64,
|
|
// #[config(default = 1.0e-4)]
|
|
// pub learning_rate: f64,
|
|
// }
|
|
//
|
|
// impl ChessTrainerConfig {
|
|
// pub fn init<B: Backend>(
|
|
// model_config: ChessModelConfig,
|
|
// optimizer: AdamConfig,
|
|
// num_epochs: usize,
|
|
// batch_size: usize,
|
|
// num_workers: usize,
|
|
// seed: u64,
|
|
// learning_rate: f64,
|
|
// ) -> ChessTrainer<B> {
|
|
// ChessTrainer {
|
|
// model: model_config::init(),
|
|
// optimizer: optimizer.init(),
|
|
// num_epochs,
|
|
//
|
|
//
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
// pub struct ChessTrainer<B: AutodiffBackend> {
|
|
// pub model: ChessModel<B>,
|
|
// pub optimizer: Adam,
|
|
// learning_rate: f64,
|
|
// pub batcher: ChessBatcher,
|
|
// pub device: B::Device,
|
|
// }
|
|
//
|
|
// impl<B: AutodiffBackend> ChessTrainer<B> {
|
|
// pub fn new(model: ChessModel<B>, device: B::Device) -> Self {
|
|
// let optimizer = AdamConfig::new().init();
|
|
//
|
|
// Self {
|
|
// model,
|
|
// optimizer,
|
|
// batcher: ChessBatcher::default(),
|
|
// device,
|
|
// }
|
|
// }
|
|
//
|
|
// pub fn train_step(&mut self, batch_data: Vec<MctsResults>) -> f32 {
|
|
// // 1. Convert to tensors
|
|
// let batch = self.batcher.batch(batch_data, &self.device);
|
|
//
|
|
// // 2. Forward + backward (your TrainStep impl)
|
|
// let output = self.model.step(batch);
|
|
//
|
|
// // 3. Update weights
|
|
// self.optimizer.step(, &mut self.model, output.grads);
|
|
//
|
|
// // 4. Return loss (for logging)
|
|
// let loss_tensor = output.item.loss.clone().into_data();
|
|
// let loss = loss_tensor.to_vec::<f32>().unwrap()[0];
|
|
//
|
|
// loss
|
|
// }
|
|
// }
|
|
|
|
// fn create_artifact_dir(artifact_dir: &str) {
|
|
// // Remove any existing artifacts first so that the learner summary is accurate
|
|
// std::fs::remove_dir_all(artifact_dir).ok();
|
|
// std::fs::create_dir_all(artifact_dir).ok();
|
|
// }
|
|
//
|
|
// pub fn train<B: AutodiffBackend>(
|
|
// artifact_dir: &str,
|
|
// config: ChessTrainerConfig,
|
|
// device: B::Device,
|
|
// ) {
|
|
// create_artifact_dir(artifact_dir);
|
|
// config
|
|
// .save(format!("{artifact_dir}/config.json"))
|
|
// .expect("Config should be saved successfully");
|
|
//
|
|
// B::seed(&device, config.seed);
|
|
//
|
|
// let batcher = ChessBatcher::default();
|
|
|
|
// let dataloader_train = DataLoaderBuilder::new(batcher.clone())
|
|
// .batch_size(config.batch_size)
|
|
// .shuffle(config.seed)
|
|
// .num_workers(config.num_workers)
|
|
// .build(MnistDataset::train());
|
|
//
|
|
// let dataloader_test = DataLoaderBuilder::new(batcher)
|
|
// .batch_size(config.batch_size)
|
|
// .shuffle(config.seed)
|
|
// .num_workers(config.num_workers)
|
|
// .build(MnistDataset::test());
|
|
//
|
|
// let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)
|
|
// .metrics((AccuracyMetric::new(), LossMetric::new()))
|
|
// .with_file_checkpointer(CompactRecorder::new())
|
|
// .num_epochs(config.num_epochs)
|
|
// .summary();
|
|
//
|
|
// let model = config.model.init::<B>(&device);
|
|
// let result = training.launch(Learner::new(
|
|
// model,
|
|
// config.optimizer.init(),
|
|
// config.learning_rate,
|
|
// ));
|
|
//
|
|
// result
|
|
// .model
|
|
// .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
|
// .expect("Trained model should be saved successfully");
|
|
// }
|