diff --git a/project/src/neural.rs b/project/src/neural.rs index a834664..812baa1 100644 --- a/project/src/neural.rs +++ b/project/src/neural.rs @@ -11,6 +11,9 @@ use burn::{ optim::AdamConfig, }; +// Constants +const MODEL_DIRECTORY: &str = "./out"; + // Structures pub struct NeuralNetwork { mode: OperationMode, @@ -26,25 +29,23 @@ impl NeuralNetwork { // Functions fn train(&self) { + // Creating Backend type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device let device = burn::backend::wgpu::WgpuDevice::default(); - // All the training artifacts will be saved in this directory - let artifact_dir = "/tmp/guide"; - // Train the model training::train::( - artifact_dir, + MODEL_DIRECTORY, training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); // Infer the model infrence::infer::( - artifact_dir, + MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42) @@ -52,16 +53,13 @@ impl NeuralNetwork { ); } fn infer(&self) { + // Creating Backend type MyBackend = Wgpu; - let device = burn::backend::wgpu::WgpuDevice::default(); - // All the training artifacts are saved in this directory - let artifact_dir = "/tmp/guide"; - // Infer the model infrence::infer::( - artifact_dir, + MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42)