From b9b07c4b75f2525b5442a4312523543c4a782877 Mon Sep 17 00:00:00 2001 From: Maddox Werts Date: Wed, 19 Mar 2025 16:00:27 -0400 Subject: [PATCH] Code Cleanup and Store in Out folder --- project/src/neural.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/project/src/neural.rs b/project/src/neural.rs index a834664..812baa1 100644 --- a/project/src/neural.rs +++ b/project/src/neural.rs @@ -11,6 +11,9 @@ use burn::{ optim::AdamConfig, }; +// Constants +const MODEL_DIRECTORY: &str = "./out"; + // Structures pub struct NeuralNetwork { mode: OperationMode, @@ -26,25 +29,23 @@ impl NeuralNetwork { // Functions fn train(&self) { + // Creating Backend type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device let device = burn::backend::wgpu::WgpuDevice::default(); - // All the training artifacts will be saved in this directory - let artifact_dir = "/tmp/guide"; - // Train the model training::train::( - artifact_dir, + MODEL_DIRECTORY, training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); // Infer the model infrence::infer::( - artifact_dir, + MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42) @@ -52,16 +53,13 @@ impl NeuralNetwork { ); } fn infer(&self) { + // Creating Backend type MyBackend = Wgpu; - let device = burn::backend::wgpu::WgpuDevice::default(); - // All the training artifacts are saved in this directory - let artifact_dir = "/tmp/guide"; - // Infer the model infrence::infer::( - artifact_dir, + MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42)