diff --git a/project/src/neural.rs b/project/src/neural.rs index 45e086c..5623f2c 100644 --- a/project/src/neural.rs +++ b/project/src/neural.rs @@ -3,9 +3,14 @@ mod data; mod infrence; mod model; mod training; - use super::config::OperationMode; +use burn::{ + backend::{Autodiff, WebGpu}, + data::dataset::Dataset, + optim::AdamConfig, +}; + // Structures pub struct NeuralNetwork { mode: OperationMode, @@ -20,11 +25,55 @@ impl NeuralNetwork { } // Functions + fn train(&self) { + type MyBackend = WebGpu; + type MyAutodiffBackend = Autodiff; + + // Create a default Wgpu device + let device = burn::backend::wgpu::WgpuDevice::default(); + + // All the training artifacts will be saved in this directory + let artifact_dir = "/tmp/guide"; + + // Train the model + training::train::( + artifact_dir, + training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), + device.clone(), + ); + + // Infer the model + inference::infer::( + artifact_dir, + device, + burn::data::dataset::vision::MnistDataset::test() + .get(42) + .unwrap(), + ); + } + fn infer(&self) { + type MyBackend = WebGpu; + + let device = burn::backend::wgpu::WgpuDevice::default(); + + // All the training artifacts are saved in this directory + let artifact_dir = "/tmp/guide"; + + // Infer the model + infrence::infer::( + artifact_dir, + device, + burn::data::dataset::vision::MnistDataset::test() + .get(42) + .unwrap(), + ); + } + pub fn start(&self) { // Switching based on mode match self.mode { - OperationMode::Training => {} - OperationMode::Infrence => {} + OperationMode::Training => self.train(), + OperationMode::Infrence => self.infer(), } } }