// use crate::model::{ChessBatcher, ChessModel, ChessModelConfig}; // use burn::config::Config; // use burn::data::dataloader::batcher::Batcher; // use burn::optim::{Adam, AdamConfig, Optimizer, SimpleOptimizer}; // use burn::optim::adaptor::OptimizerAdaptor; // use burn::prelude::Backend; // use burn::tensor::backend::AutodiffBackend; // use burn::train::TrainStep; // use crate::mcts::MctsResults; // // #[derive(Config, Debug)] // pub struct ChessTrainerConfig { // pub model: ChessModelConfig, // pub optimizer: AdamConfig, // #[config(default = 10)] // pub num_epochs: usize, // #[config(default = 64)] // pub batch_size: usize, // #[config(default = 4)] // pub num_workers: usize, // #[config(default = 42)] // pub seed: u64, // #[config(default = 1.0e-4)] // pub learning_rate: f64, // } // // impl ChessTrainerConfig { // pub fn init( // model_config: ChessModelConfig, // optimizer: AdamConfig, // num_epochs: usize, // batch_size: usize, // num_workers: usize, // seed: u64, // learning_rate: f64, // ) -> ChessTrainer { // ChessTrainer { // model: model_config::init(), // optimizer: optimizer.init(), // num_epochs, // // // } // } // } // // pub struct ChessTrainer { // pub model: ChessModel, // pub optimizer: Adam, // learning_rate: f64, // pub batcher: ChessBatcher, // pub device: B::Device, // } // // impl ChessTrainer { // pub fn new(model: ChessModel, device: B::Device) -> Self { // let optimizer = AdamConfig::new().init(); // // Self { // model, // optimizer, // batcher: ChessBatcher::default(), // device, // } // } // // pub fn train_step(&mut self, batch_data: Vec) -> f32 { // // 1. Convert to tensors // let batch = self.batcher.batch(batch_data, &self.device); // // // 2. Forward + backward (your TrainStep impl) // let output = self.model.step(batch); // // // 3. Update weights // self.optimizer.step(, &mut self.model, output.grads); // // // 4. Return loss (for logging) // let loss_tensor = output.item.loss.clone().into_data(); // let loss = loss_tensor.to_vec::().unwrap()[0]; // // loss // } // } // fn create_artifact_dir(artifact_dir: &str) { // // Remove existing artifacts before to get an accurate learner summary // std::fs::remove_dir_all(artifact_dir).ok(); // std::fs::create_dir_all(artifact_dir).ok(); // } // // pub fn train( // artifact_dir: &str, // config: ChessTrainingConfig, // device: B::Device, // ) { // create_artifact_dir(artifact_dir); // config // .save(format!("{artifact_dir}/config.json")) // .expect("Config should be saved successfully"); // // B::seed(&device, config.seed); // // let batcher = ChessBatcher::default(); // let dataloader_train = DataLoaderBuilder::new(batcher.clone()) // .batch_size(config.batch_size) // .shuffle(config.seed) // .num_workers(config.num_workers) // .build(MnistDataset::train()); // // let dataloader_test = DataLoaderBuilder::new(batcher) // .batch_size(config.batch_size) // .shuffle(config.seed) // .num_workers(config.num_workers) // .build(MnistDataset::test()); // // let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) // .metrics((AccuracyMetric::new(), LossMetric::new())) // .with_file_checkpointer(CompactRecorder::new()) // .num_epochs(config.num_epochs) // .summary(); // // let model = config.model.init::(&device); // let result = training.launch(Learner::new( // model, // config.optimizer.init(), // config.learning_rate, // )); // // result // .model // .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) // .expect("Trained model should be saved successfully"); // }