chess_dragon/engine/src/train.rs
2026-05-23 15:06:10 -05:00

135 lines
4.0 KiB
XML

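// NOTE(review): this whole module is commented out, and the draft below does not compile as
// written. Issues to resolve before re-enabling it:
// - `ChessTrainerConfig::init` calls `model_config::init()` (path syntax on a parameter)
//   instead of a method on the config value. It also declares `B: Backend`, although
//   `ChessTrainer` requires `B: AutodiffBackend`. Finally, it sets a `num_epochs` field
//   that `ChessTrainer` does not have and omits `learning_rate`, `batcher` and `device`.
// - `ChessTrainer::new` does not initialise the `learning_rate` field.
// - `train_step` passes an empty first argument to `self.optimizer.step(...)`. Burn's
//   `Optimizer::step` appears to take the learning rate first and the module by value,
//   returning the updated module. Verify against the Burn version in use.
// - `train` takes a `ChessTrainingConfig`, but the struct defined here is
//   `ChessTrainerConfig`. It also builds its loaders from `MnistDataset`, which looks like a
//   leftover placeholder from Burn's MNIST example.
// - `DataLoaderBuilder`, `SupervisedTraining`, `Learner`, `AccuracyMetric`, `LossMetric` and
//   `CompactRecorder` are not imported.
// use crate::model::{ChessBatcher, ChessModel, ChessModelConfig};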
// use burn::config::Config;
// use burn::data::dataloader::batcher::Batcher;
// use burn::optim::{Adam, AdamConfig, Optimizer, SimpleOptimizer};
// use burn::optim::adaptor::OptimizerAdaptor;
// use burn::prelude::Backend;
// use burn::tensor::backend::AutodiffBackend;
// use burn::train::TrainStep;
// use crate::mcts::MctsResults;
//
// #[derive(Config, Debug)]
// pub struct ChessTrainerConfig {
// pub model: ChessModelConfig,
// pub optimizer: AdamConfig,
// #[config(default = 10)]
// pub num_epochs: usize,
// #[config(default = 64)]
// pub batch_size: usize,
// #[config(default = 4)]
// pub num_workers: usize,
// #[config(default = 42)]
// pub seed: u64,
// #[config(default = 1.0e-4)]
// pub learning_rate: f64,
// }
//
// impl ChessTrainerConfig {
// pub fn init<B: Backend>(
// model_config: ChessModelConfig,
// optimizer: AdamConfig,
// num_epochs: usize,
// batch_size: usize,
// num_workers: usize,
// seed: u64,
// learning_rate: f64,
// ) -> ChessTrainer<B> {
// ChessTrainer {
// model: model_config::init(),
// optimizer: optimizer.init(),
// num_epochs,
//
//
// }
// }
// }
//
// pub struct ChessTrainer<B: AutodiffBackend> {
// pub model: ChessModel<B>,
// pub optimizer: Adam,
// learning_rate: f64,
// pub batcher: ChessBatcher,
// pub device: B::Device,
// }
//
// impl<B: AutodiffBackend> ChessTrainer<B> {
// pub fn new(model: ChessModel<B>, device: B::Device) -> Self {
// let optimizer = AdamConfig::new().init();
//
// Self {
// model,
// optimizer,
// batcher: ChessBatcher::default(),
// device,
// }
// }
//
// pub fn train_step(&mut self, batch_data: Vec<MctsResults>) -> f32 {
// // 1. Convert to tensors
// let batch = self.batcher.batch(batch_data, &self.device);
//
// // 2. Forward + backward pass via ChessModel's TrainStep::step
// let output = self.model.step(batch);
//
// // 3. Update weights
// self.optimizer.step(, &mut self.model, output.grads);
//
// // 4. Return loss (for logging)
// let loss_tensor = output.item.loss.clone().into_data();
// let loss = loss_tensor.to_vec::<f32>().unwrap()[0];
//
// loss
// }
// }
// fn create_artifact_dir(artifact_dir: &str) {
// // Remove any existing artifacts first so the learner summary reflects only this run
// std::fs::remove_dir_all(artifact_dir).ok();
// std::fs::create_dir_all(artifact_dir).ok();
// }
//
// pub fn train<B: AutodiffBackend>(
// artifact_dir: &str,
// config: ChessTrainingConfig,
// device: B::Device,
// ) {
// create_artifact_dir(artifact_dir);
// config
// .save(format!("{artifact_dir}/config.json"))
// .expect("Config should be saved successfully");
//
// B::seed(&device, config.seed);
//
// let batcher = ChessBatcher::default();
// let dataloader_train = DataLoaderBuilder::new(batcher.clone())
// .batch_size(config.batch_size)
// .shuffle(config.seed)
// .num_workers(config.num_workers)
// .build(MnistDataset::train());
//
// let dataloader_test = DataLoaderBuilder::new(batcher)
// .batch_size(config.batch_size)
// .shuffle(config.seed)
// .num_workers(config.num_workers)
// .build(MnistDataset::test());
//
// let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)
// .metrics((AccuracyMetric::new(), LossMetric::new()))
// .with_file_checkpointer(CompactRecorder::new())
// .num_epochs(config.num_epochs)
// .summary();
//
// let model = config.model.init::<B>(&device);
// let result = training.launch(Learner::new(
// model,
// config.optimizer.init(),
// config.learning_rate,
// ));
//
// result
// .model
// .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
// .expect("Trained model should be saved successfully");
// }