Compare commits

..

No commits in common. "ef7046b2c84954e0c0aac083696e4bb9acbd95ee" and "46149b885c0cee867a97b18a525820355ae7fb16" have entirely different histories.

4 changed files with 18 additions and 16 deletions

View file

@ -4,7 +4,7 @@ use std::env;
// Enums // Enums
pub enum OperationMode { pub enum OperationMode {
Training, Training,
Inference, Infrence,
} }
// Functions // Functions
@ -15,7 +15,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
// Getting operation mode // Getting operation mode
match args[1].as_str() { match args[1].as_str() {
"training" => Some(OperationMode::Training), "training" => Some(OperationMode::Training),
"inference" => Some(OperationMode::Inference), "infrence" => Some(OperationMode::Infrence),
_ => None, _ => None,
} }
} }

View file

@ -13,10 +13,10 @@ fn main() {
// Creating a Neural Network with the Operation Mode // Creating a Neural Network with the Operation Mode
match operation_mode { match operation_mode {
None => panic!("Main: `OperationMode` not defined!"),
Some(mode) => { Some(mode) => {
neural = NeuralNetwork::new(mode); neural = NeuralNetwork::new(mode);
}, }
_ => panic!("Main: `OperationMode` not defined!"),
} }
// Starting the network // Starting the network

View file

@ -1,6 +1,6 @@
// Libraries // Libraries
mod data; mod data;
mod inference; mod infrence;
mod model; mod model;
mod training; mod training;
use super::config::OperationMode; use super::config::OperationMode;
@ -11,9 +11,6 @@ use burn::{
optim::AdamConfig, optim::AdamConfig,
}; };
// Constants
const MODEL_DIRECTORY: &str = "./out";
// Structures // Structures
pub struct NeuralNetwork { pub struct NeuralNetwork {
mode: OperationMode, mode: OperationMode,
@ -29,23 +26,25 @@ impl NeuralNetwork {
// Functions // Functions
fn train(&self) { fn train(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device // Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
// Train the model // Train the model
training::train::<MyAutodiffBackend>( training::train::<MyAutodiffBackend>(
MODEL_DIRECTORY, artifact_dir,
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(), device.clone(),
); );
// Infer the model // Infer the model
inference::infer::<MyBackend>( infrence::infer::<MyBackend>(
MODEL_DIRECTORY, artifact_dir,
device, device,
burn::data::dataset::vision::MnistDataset::test() burn::data::dataset::vision::MnistDataset::test()
.get(42) .get(42)
@ -53,13 +52,16 @@ impl NeuralNetwork {
); );
} }
fn infer(&self) { fn infer(&self) {
// Creating Backend
type MyBackend = Wgpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";
// Infer the model // Infer the model
inference::infer::<MyBackend>( infrence::infer::<MyBackend>(
MODEL_DIRECTORY, artifact_dir,
device, device,
burn::data::dataset::vision::MnistDataset::test() burn::data::dataset::vision::MnistDataset::test()
.get(42) .get(42)
@ -71,7 +73,7 @@ impl NeuralNetwork {
// Switching based on mode // Switching based on mode
match self.mode { match self.mode {
OperationMode::Training => self.train(), OperationMode::Training => self.train(),
OperationMode::Inference => self.infer(), OperationMode::Infrence => self.infer(),
} }
} }
} }