From d68f59aedd21d6d802cd07390b3fdf4929fa9f46 Mon Sep 17 00:00:00 2001 From: Maddox Werts Date: Mon, 17 Mar 2025 10:14:38 -0400 Subject: [PATCH] Moved BURN lib guide to neural --- project/src/{ => neural}/data.rs | 0 project/src/{ => neural}/infrence.rs | 0 project/src/{ => neural}/model.rs | 0 project/src/neural/training.rs | 112 +++++++++++++++++++++++++++ 4 files changed, 112 insertions(+) rename project/src/{ => neural}/data.rs (100%) rename project/src/{ => neural}/infrence.rs (100%) rename project/src/{ => neural}/model.rs (100%) create mode 100644 project/src/neural/training.rs diff --git a/project/src/data.rs b/project/src/neural/data.rs similarity index 100% rename from project/src/data.rs rename to project/src/neural/data.rs diff --git a/project/src/infrence.rs b/project/src/neural/infrence.rs similarity index 100% rename from project/src/infrence.rs rename to project/src/neural/infrence.rs diff --git a/project/src/model.rs b/project/src/neural/model.rs similarity index 100% rename from project/src/model.rs rename to project/src/neural/model.rs diff --git a/project/src/neural/training.rs b/project/src/neural/training.rs new file mode 100644 index 0000000..682e9bb --- /dev/null +++ b/project/src/neural/training.rs @@ -0,0 +1,112 @@ +use crate::{ + data::{MnistBatch, MnistBatcher}, + model::{Model, ModelConfig}, +}; +use burn::{ + data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, + nn::loss::CrossEntropyLossConfig, + optim::AdamConfig, + prelude::*, + record::CompactRecorder, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, LossMetric}, + ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, + }, +}; + +impl Model { + pub fn forward_classification( + &self, + images: Tensor, + targets: Tensor, + ) -> ClassificationOutput { + let output = self.forward(images); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), targets.clone()); + + ClassificationOutput::new(loss, output, targets) + } +} + +impl TrainStep, ClassificationOutput> for Model { + fn step(&self, batch: MnistBatch) -> TrainOutput> { + let item = self.forward_classification(batch.images, batch.targets); + + TrainOutput::new(self, item.loss.backward(), item) + } +} + +impl ValidStep, ClassificationOutput> for Model { + fn step(&self, batch: MnistBatch) -> ClassificationOutput { + self.forward_classification(batch.images, batch.targets) + } +} + +#[derive(Config)] +pub struct TrainingConfig { + pub model: ModelConfig, + pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1.0e-4)] + pub learning_rate: f64, +} + +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts before to get an accurate learner summary + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + + B::seed(config.seed); + + let batcher_train = MnistBatcher::::new(device.clone()); + let batcher_valid = MnistBatcher::::new(device.clone()); + + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MnistDataset::train()); + + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MnistDataset::test()); + + let learner = LearnerBuilder::new(artifact_dir) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device.clone()]) + .num_epochs(config.num_epochs) + .summary() + .build( + config.model.init::(&device), + config.optimizer.init(), + config.learning_rate, + ); + + let model_trained = learner.fit(dataloader_train, dataloader_test); + + model_trained + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Trained model should be saved successfully"); +}