diff --git a/project/src/neural.rs b/project/src/neural.rs
new file mode 100644
index 0000000..45e086c
--- /dev/null
+++ b/project/src/neural.rs
@@ -0,0 +1,30 @@
+// Libraries
+mod data;
+mod infrence;
+mod model;
+mod training;
+
+use super::config::OperationMode;
+
+// Structures
+pub struct NeuralNetwork {
+    mode: OperationMode,
+}
+
+// Implementations
+impl NeuralNetwork {
+    // Constructors
+    pub fn new(mode: OperationMode) -> Self {
+        // TODO: return Result<Self, _> if construction becomes fallible
+        Self { mode }
+    }
+
+    // Functions
+    pub fn start(&self) {
+        // Dispatch based on the configured operation mode
+        match self.mode {
+            OperationMode::Training => {}
+            OperationMode::Infrence => {}
+        }
+    }
+}
diff --git a/project/src/training.rs b/project/src/training.rs
deleted file mode 100644
index 682e9bb..0000000
--- a/project/src/training.rs
+++ /dev/null
@@ -1,112 +0,0 @@
-use crate::{
-    data::{MnistBatch, MnistBatcher},
-    model::{Model, ModelConfig},
-};
-use burn::{
-    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-    nn::loss::CrossEntropyLossConfig,
-    optim::AdamConfig,
-    prelude::*,
-    record::CompactRecorder,
-    tensor::backend::AutodiffBackend,
-    train::{
-        metric::{AccuracyMetric, LossMetric},
-        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
-    },
-};
-
-impl<B: Backend> Model<B> {
-    pub fn forward_classification(
-        &self,
-        images: Tensor<B, 3>,
-        targets: Tensor<B, 1, Int>,
-    ) -> ClassificationOutput<B> {
-        let output = self.forward(images);
-        let loss = CrossEntropyLossConfig::new()
-            .init(&output.device())
-            .forward(output.clone(), targets.clone());
-
-        ClassificationOutput::new(loss, output, targets)
-    }
-}
-
-impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
-        let item = self.forward_classification(batch.images, batch.targets);
-
-        TrainOutput::new(self, item.loss.backward(), item)
-    }
-}
-
-impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
-        self.forward_classification(batch.images, batch.targets)
-    }
-}
-
-#[derive(Config)]
-pub struct TrainingConfig {
-    pub model: ModelConfig,
-    pub optimizer: AdamConfig,
-    #[config(default = 10)]
-    pub num_epochs: usize,
-    #[config(default = 64)]
-    pub batch_size: usize,
-    #[config(default = 4)]
-    pub num_workers: usize,
-    #[config(default = 42)]
-    pub seed: u64,
-    #[config(default = 1.0e-4)]
-    pub learning_rate: f64,
-}
-
-fn create_artifact_dir(artifact_dir: &str) {
-    // Remove existing artifacts before to get an accurate learner summary
-    std::fs::remove_dir_all(artifact_dir).ok();
-    std::fs::create_dir_all(artifact_dir).ok();
-}
-
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
-    create_artifact_dir(artifact_dir);
-    config
-        .save(format!("{artifact_dir}/config.json"))
-        .expect("Config should be saved successfully");
-
-    B::seed(config.seed);
-
-    let batcher_train = MnistBatcher::<B>::new(device.clone());
-    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
-
-    let dataloader_train = DataLoaderBuilder::new(batcher_train)
-        .batch_size(config.batch_size)
-        .shuffle(config.seed)
-        .num_workers(config.num_workers)
-        .build(MnistDataset::train());
-
-    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
-        .batch_size(config.batch_size)
-        .shuffle(config.seed)
-        .num_workers(config.num_workers)
-        .build(MnistDataset::test());
-
-    let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_numeric(AccuracyMetric::new())
-        .metric_valid_numeric(AccuracyMetric::new())
-        .metric_train_numeric(LossMetric::new())
-        .metric_valid_numeric(LossMetric::new())
-        .with_file_checkpointer(CompactRecorder::new())
-        .devices(vec![device.clone()])
-        .num_epochs(config.num_epochs)
-        .summary()
-        .build(
-            config.model.init::<B>(&device),
-            config.optimizer.init(),
-            config.learning_rate,
-        );
-
-    let model_trained = learner.fit(dataloader_train, dataloader_test);
-
-    model_trained
-        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
-        .expect("Trained model should be saved successfully");
-}