Compare commits

...

2 commits

Author SHA1 Message Date
ef7046b2c8 Fixed Inference spelling mistake 2025-03-19 16:01:24 -04:00
b9b07c4b75 Code Cleanup and Store in Out folder 2025-03-19 16:00:27 -04:00
4 changed files with 16 additions and 18 deletions

View file

@ -4,7 +4,7 @@ use std::env;
// Enums
pub enum OperationMode {
Training,
Infrence,
Inference,
}
// Functions
@ -15,7 +15,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
// Getting operation mode
match args[1].as_str() {
"training" => Some(OperationMode::Training),
"infrence" => Some(OperationMode::Infrence),
"inference" => Some(OperationMode::Inference),
_ => None,
}
}

View file

@ -13,10 +13,10 @@ fn main() {
// Creating a Neural Network with the Operation Mode
match operation_mode {
None => panic!("Main: `OperationMode` not defined!"),
Some(mode) => {
neural = NeuralNetwork::new(mode);
}
},
_ => panic!("Main: `OperationMode` not defined!"),
}
// Starting the network

View file

@ -1,6 +1,6 @@
// Libraries
mod data;
mod infrence;
mod inference;
mod model;
mod training;
use super::config::OperationMode;
@ -11,6 +11,9 @@ use burn::{
optim::AdamConfig,
};
// Constants
const MODEL_DIRECTORY: &str = "./out";
// Structures
pub struct NeuralNetwork {
mode: OperationMode,
@ -26,25 +29,23 @@ impl NeuralNetwork {
// Functions
fn train(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
// Train the model
training::train::<MyAutodiffBackend>(
artifact_dir,
MODEL_DIRECTORY,
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(),
);
// Infer the model
infrence::infer::<MyBackend>(
artifact_dir,
inference::infer::<MyBackend>(
MODEL_DIRECTORY,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
@ -52,16 +53,13 @@ impl NeuralNetwork {
);
}
fn infer(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";
// Infer the model
infrence::infer::<MyBackend>(
artifact_dir,
inference::infer::<MyBackend>(
MODEL_DIRECTORY,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
@ -73,7 +71,7 @@ impl NeuralNetwork {
// Switching based on mode
match self.mode {
OperationMode::Training => self.train(),
OperationMode::Infrence => self.infer(),
OperationMode::Inference => self.infer(),
}
}
}