use crate::{ data::{MnistBatch, MnistBatcher}, model::{Model, ModelConfig}, }; use burn::{ data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, nn::loss::CrossEntropyLossConfig, optim::AdamConfig, prelude::*, record::CompactRecorder, tensor::backend::AutodiffBackend, train::{ metric::{AccuracyMetric, LossMetric}, ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, }, }; impl Model { pub fn forward_classification( &self, images: Tensor, targets: Tensor, ) -> ClassificationOutput { let output = self.forward(images); let loss = CrossEntropyLossConfig::new() .init(&output.device()) .forward(output.clone(), targets.clone()); ClassificationOutput::new(loss, output, targets) } } impl TrainStep, ClassificationOutput> for Model { fn step(&self, batch: MnistBatch) -> TrainOutput> { let item = self.forward_classification(batch.images, batch.targets); TrainOutput::new(self, item.loss.backward(), item) } } impl ValidStep, ClassificationOutput> for Model { fn step(&self, batch: MnistBatch) -> ClassificationOutput { self.forward_classification(batch.images, batch.targets) } } #[derive(Config)] pub struct TrainingConfig { pub model: ModelConfig, pub optimizer: AdamConfig, #[config(default = 10)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 4)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, #[config(default = 1.0e-4)] pub learning_rate: f64, } fn create_artifact_dir(artifact_dir: &str) { // Remove existing artifacts before to get an accurate learner summary std::fs::remove_dir_all(artifact_dir).ok(); std::fs::create_dir_all(artifact_dir).ok(); } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { create_artifact_dir(artifact_dir); config .save(format!("{artifact_dir}/config.json")) .expect("Config should be saved successfully"); B::seed(config.seed); let batcher_train = MnistBatcher::::new(device.clone()); let batcher_valid = MnistBatcher::::new(device.clone()); let dataloader_train = DataLoaderBuilder::new(batcher_train) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::train()); let dataloader_test = DataLoaderBuilder::new(batcher_valid) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::test()); let learner = LearnerBuilder::new(artifact_dir) .metric_train_numeric(AccuracyMetric::new()) .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .devices(vec![device.clone()]) .num_epochs(config.num_epochs) .summary() .build( config.model.init::(&device), config.optimizer.init(), config.learning_rate, ); let model_trained = learner.fit(dataloader_train, dataloader_test); model_trained .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) .expect("Trained model should be saved successfully"); }