diff --git a/project/src/data.rs b/project/src/data.rs
new file mode 100644
index 0000000..3adc7c5
--- /dev/null
+++ b/project/src/data.rs
@@ -0,0 +1,52 @@
+// Libraries
+use burn::{
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+    prelude::*,
+};
+
+#[derive(Clone)]
+pub struct MnistBatcher<B: Backend> {
+    device: B::Device,
+}
+
+impl<B: Backend> MnistBatcher<B> {
+    pub fn new(device: B::Device) -> Self {
+        Self { device }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct MnistBatch<B: Backend> {
+    pub images: Tensor<B, 3>,
+    pub targets: Tensor<B, 1, Int>,
+}
+
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
+        let images = items
+            .iter()
+            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
+            .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
+            .map(|tensor| tensor.reshape([1, 28, 28]))
+            // Normalize: scale to [0,1], then shift to mean = 0 and std = 1.
+            // The values mean=0.1307, std=0.3081 come from the PyTorch MNIST example:
+            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
+            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
+            .collect();
+
+        let targets = items
+            .iter()
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
+            .collect();
+
+        let images = Tensor::cat(images, 0);
+        let targets = Tensor::cat(targets, 0);
+
+        MnistBatch { images, targets }
+    }
+}
diff --git a/project/src/inference.rs b/project/src/inference.rs
new file mode 100644
index 0000000..9c7477f
--- /dev/null
+++ b/project/src/inference.rs
@@ -0,0 +1,24 @@
+use crate::{data::MnistBatcher, model::Model, training::TrainingConfig};
+use burn::{
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+    prelude::*,
+    record::{CompactRecorder, Recorder},
+};
+
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
+    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
+        .expect("Config should exist for the model; run train first");
+    let record = CompactRecorder::new()
+        .load(format!("{artifact_dir}/model").into(), &device)
+        .expect("Trained model should exist; run train first");
+
+    let model: Model<B> = config.model.init(&device).load_record(record);
+
+    let label = item.label;
+    let batcher = MnistBatcher::new(device);
+    let batch = batcher.batch(vec![item]);
+    let output = model.forward(batch.images);
+    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
+
+    println!("Predicted {} Expected {}", predicted, label);
+}
diff --git a/project/src/model.rs b/project/src/model.rs
new file mode 100644
index 0000000..705729c
--- /dev/null
+++ b/project/src/model.rs
@@ -0,0 +1,68 @@
+use burn::{
+    nn::{
+        conv::{Conv2d, Conv2dConfig},
+        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
+        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
+    },
+    prelude::*,
+};
+
+#[derive(Module, Debug)]
+pub struct Model<B: Backend> {
+    conv1: Conv2d<B>,
+    conv2: Conv2d<B>,
+    pool: AdaptiveAvgPool2d,
+    dropout: Dropout,
+    linear1: Linear<B>,
+    linear2: Linear<B>,
+    activation: Relu,
+}
+
+#[derive(Config, Debug)]
+pub struct ModelConfig {
+    num_classes: usize,
+    hidden_size: usize,
+    #[config(default = "0.5")]
+    dropout: f64,
+}
+
+impl ModelConfig {
+    /// Returns the initialized model.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
+        Model {
+            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
+            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
+            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
+            activation: Relu::new(),
+            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
+            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
+            dropout: DropoutConfig::new(self.dropout).init(),
+        }
+    }
+}
+
+impl<B: Backend> Model<B> {
+    /// # Shapes
+    ///   - Images [batch_size, height, width]
+    ///   - Output [batch_size, class_prob]
+    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
+        let [batch_size, height, width] = images.dims();
+
+        // Create a channel dimension.
+        let x = images.reshape([batch_size, 1, height, width]);
+
+        let x = self.conv1.forward(x); // [batch_size, 8, _, _]
+        let x = self.dropout.forward(x);
+        let x = self.conv2.forward(x); // [batch_size, 16, _, _]
+        let x = self.dropout.forward(x);
+        let x = self.activation.forward(x);
+
+        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]
+        let x = x.reshape([batch_size, 16 * 8 * 8]);
+        let x = self.linear1.forward(x);
+        let x = self.dropout.forward(x);
+        let x = self.activation.forward(x);
+
+        self.linear2.forward(x) // [batch_size, num_classes]
+    }
+}
diff --git a/project/src/training.rs b/project/src/training.rs
new file mode 100644
index 0000000..682e9bb
--- /dev/null
+++ b/project/src/training.rs
@@ -0,0 +1,112 @@
+use crate::{
+    data::{MnistBatch, MnistBatcher},
+    model::{Model, ModelConfig},
+};
+use burn::{
+    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+    nn::loss::CrossEntropyLossConfig,
+    optim::AdamConfig,
+    prelude::*,
+    record::CompactRecorder,
+    tensor::backend::AutodiffBackend,
+    train::{
+        metric::{AccuracyMetric, LossMetric},
+        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
+    },
+};
+
+impl<B: Backend> Model<B> {
+    pub fn forward_classification(
+        &self,
+        images: Tensor<B, 3>,
+        targets: Tensor<B, 1, Int>,
+    ) -> ClassificationOutput<B> {
+        let output = self.forward(images);
+        let loss = CrossEntropyLossConfig::new()
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
+
+        ClassificationOutput::new(loss, output, targets)
+    }
+}
+
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+        let item = self.forward_classification(batch.images, batch.targets);
+
+        TrainOutput::new(self, item.loss.backward(), item)
+    }
+}
+
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
+        self.forward_classification(batch.images, batch.targets)
+    }
+}
+
+#[derive(Config)]
+pub struct TrainingConfig {
+    pub model: ModelConfig,
+    pub optimizer: AdamConfig,
+    #[config(default = 10)]
+    pub num_epochs: usize,
+    #[config(default = 64)]
+    pub batch_size: usize,
+    #[config(default = 4)]
+    pub num_workers: usize,
+    #[config(default = 42)]
+    pub seed: u64,
+    #[config(default = 1.0e-4)]
+    pub learning_rate: f64,
+}
+
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
+    config
+        .save(format!("{artifact_dir}/config.json"))
+        .expect("Config should be saved successfully");
+
+    B::seed(config.seed);
+
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
+
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
+        .batch_size(config.batch_size)
+        .shuffle(config.seed)
+        .num_workers(config.num_workers)
+        .build(MnistDataset::train());
+
+    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
+        .batch_size(config.batch_size)
+        .shuffle(config.seed)
+        .num_workers(config.num_workers)
+        .build(MnistDataset::test());
+
+    let learner = LearnerBuilder::new(artifact_dir)
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
+        .with_file_checkpointer(CompactRecorder::new())
+        .devices(vec![device.clone()])
+        .num_epochs(config.num_epochs)
+        .summary()
+        .build(
+            config.model.init::<B>(&device),
+            config.optimizer.init(),
+            config.learning_rate,
+        );
+
+    let model_trained = learner.fit(dataloader_train, dataloader_test);
+
+    model_trained
+        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
+        .expect("Trained model should be saved successfully");
+}
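
The four modules above still need an entry point that wires train and infer together. Below is a minimal main.rs sketch in the same diff style; the Wgpu backend, the /tmp/guide artifact directory, the hidden size of 512, and the test-set index 42 are illustrative assumptions rather than fixed choices.

diff --git a/project/src/main.rs b/project/src/main.rs
new file mode 100644
--- /dev/null
+++ b/project/src/main.rs
@@ -0,0 +1,32 @@
+mod data;
+mod inference;
+mod model;
+mod training;
+
+use crate::{model::ModelConfig, training::TrainingConfig};
+use burn::{
+    backend::{Autodiff, Wgpu},
+    data::dataset::Dataset,
+    optim::AdamConfig,
+};
+
+fn main() {
+    // Assumption: any Burn backend works here; Wgpu is just one choice.
+    type MyBackend = Wgpu<f32, i32>;
+    type MyAutodiffBackend = Autodiff<MyBackend>;
+
+    let device = burn::backend::wgpu::WgpuDevice::default();
+    // Assumption: artifact directory and hidden size; adjust to your setup.
+    let artifact_dir = "/tmp/guide";
+    crate::training::train::<MyAutodiffBackend>(
+        artifact_dir,
+        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
+        device.clone(),
+    );
+    // Predict a single test item (index 42 is arbitrary) with the trained model.
+    crate::inference::infer::<MyBackend>(
+        artifact_dir,
+        device,
+        burn::data::dataset::vision::MnistDataset::test().get(42).unwrap(),
+    );
+}

Note that training uses the autodiff-wrapped backend while inference uses the plain one, which is why train takes B: AutodiffBackend and builds its validation batcher on B::InnerBackend.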