Code Cleanup and Store in Out folder

This commit is contained in:
Maddox Werts 2025-03-19 16:00:27 -04:00
parent 46149b885c
commit b9b07c4b75

View file

@ -11,6 +11,9 @@ use burn::{
optim::AdamConfig, optim::AdamConfig,
}; };
// Constants
const MODEL_DIRECTORY: &str = "./out";
// Structures // Structures
pub struct NeuralNetwork { pub struct NeuralNetwork {
mode: OperationMode, mode: OperationMode,
@ -26,25 +29,23 @@ impl NeuralNetwork {
// Functions // Functions
fn train(&self) { fn train(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device // Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
// Train the model // Train the model
training::train::<MyAutodiffBackend>( training::train::<MyAutodiffBackend>(
artifact_dir, MODEL_DIRECTORY,
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(), device.clone(),
); );
// Infer the model // Infer the model
infrence::infer::<MyBackend>( infrence::infer::<MyBackend>(
artifact_dir, MODEL_DIRECTORY,
device, device,
burn::data::dataset::vision::MnistDataset::test() burn::data::dataset::vision::MnistDataset::test()
.get(42) .get(42)
@ -52,16 +53,13 @@ impl NeuralNetwork {
); );
} }
fn infer(&self) { fn infer(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";
// Infer the model // Infer the model
infrence::infer::<MyBackend>( infrence::infer::<MyBackend>(
artifact_dir, MODEL_DIRECTORY,
device, device,
burn::data::dataset::vision::MnistDataset::test() burn::data::dataset::vision::MnistDataset::test()
.get(42) .get(42)