generated from OBJNULL/Dockerized-Rust

Compare commits: 46149b885c...ef7046b2c8 (2 commits: b9b07c4b75, ef7046b2c8)

4 changed files with 16 additions and 18 deletions
@@ -4,7 +4,7 @@ use std::env;
 // Enums
 pub enum OperationMode {
     Training,
-    Infrence,
+    Inference,
 }
 
 // Functions
@@ -15,7 +15,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
     // Getting operation mode
     match args[1].as_str() {
         "training" => Some(OperationMode::Training),
-        "infrence" => Some(OperationMode::Infrence),
+        "inference" => Some(OperationMode::Inference),
         _ => None,
     }
 }
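Both hunks above touch what is presumably the crate's config module (a later hunk imports `super::config::OperationMode`). Below is a minimal sketch of how the enum and the argument parser read after the rename; everything outside the two hunks, in particular how `args` is built, is an assumption based on the `use std::env;` context line:

use std::env;

// Enums
pub enum OperationMode {
    Training,
    Inference,
}

// Functions
pub fn get_operation_mode() -> Option<OperationMode> {
    // Assumption: the arguments are collected from the process command line.
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        return None;
    }

    // Getting operation mode
    match args[1].as_str() {
        "training" => Some(OperationMode::Training),
        "inference" => Some(OperationMode::Inference),
        _ => None,
    }
}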
@@ -13,10 +13,10 @@ fn main() {
 
     // Creating a Neural Network with the Operation Mode
     match operation_mode {
-        None => panic!("Main: `OperationMode` not defined!"),
         Some(mode) => {
             neural = NeuralNetwork::new(mode);
         },
-        }
+        _ => panic!("Main: `OperationMode` not defined!"),
     }
 
     // Starting the network
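The change in `main` drops the explicit `None` arm in favour of a trailing catch-all, so the constructor arm comes first. A sketch of the surrounding code, assuming `neural` is declared beforehand (as the assignment form in the hunk suggests) and that `operation_mode` comes from the config parser above; only the match body is taken from the diff:

// Hypothetical surrounding bindings; only the match is shown in the hunk.
let operation_mode = config::get_operation_mode();
let neural: NeuralNetwork;

// Creating a Neural Network with the Operation Mode
match operation_mode {
    Some(mode) => {
        neural = NeuralNetwork::new(mode);
    },
    _ => panic!("Main: `OperationMode` not defined!"),
}
// `neural` is presumably used further down ("Starting the network").

With only the `Some` and `_` arms, the catch-all covers the `None` case the removed arm handled, so behaviour is unchanged; the deferred initialization of `neural` is accepted because the panicking arm never falls through.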
@@ -1,6 +1,6 @@
 // Libraries
 mod data;
-mod infrence;
+mod inference;
 mod model;
 mod training;
 use super::config::OperationMode;
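Renaming the declaration from `mod infrence;` to `mod inference;` only compiles if the backing source file is renamed to match (e.g. `infrence.rs` becoming `inference.rs`, assuming a file-per-module layout); that rename is presumably the fourth changed file in this compare, since only three files show hunks here.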
@@ -11,6 +11,9 @@ use burn::{
     optim::AdamConfig,
 };
 
+// Constants
+const MODEL_DIRECTORY: &str = "./out";
+
 // Structures
 pub struct NeuralNetwork {
     mode: OperationMode,
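Hoisting the path into `MODEL_DIRECTORY` means `train()` and `infer()` below share `./out` instead of each hard-coding `/tmp/guide`. Because the constant is a `&'static str`, it can be passed anywhere the old local `artifact_dir` was; a hypothetical helper, not part of the diff, just to illustrate:

const MODEL_DIRECTORY: &str = "./out";

// Hypothetical helper (not in the repository): builds a path inside the
// artifact directory.
fn model_file(name: &str) -> std::path::PathBuf {
    std::path::Path::new(MODEL_DIRECTORY).join(name)
}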
@@ -26,25 +29,23 @@ impl NeuralNetwork {
 
     // Functions
     fn train(&self) {
+        // Creating Backend
         type MyBackend = Wgpu<f32, i32>;
         type MyAutodiffBackend = Autodiff<MyBackend>;
 
         // Create a default Wgpu device
         let device = burn::backend::wgpu::WgpuDevice::default();
 
-        // All the training artifacts will be saved in this directory
-        let artifact_dir = "/tmp/guide";
-
         // Train the model
         training::train::<MyAutodiffBackend>(
-            artifact_dir,
+            MODEL_DIRECTORY,
             training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
             device.clone(),
         );
 
         // Infer the model
-        infrence::infer::<MyBackend>(
-            artifact_dir,
+        inference::infer::<MyBackend>(
+            MODEL_DIRECTORY,
             device,
             burn::data::dataset::vision::MnistDataset::test()
                 .get(42)
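For reference, the two backend aliases the training path relies on, spelled out with their imports; this restates what the hunk already uses rather than adding new project code, and explains why `train()` needs the autodiff wrapper while `infer()` below does not:

use burn::backend::{Autodiff, Wgpu};

// Wgpu<f32, i32>: Burn's WGPU backend with f32 floats and i32 ints.
type MyBackend = Wgpu<f32, i32>;
// Autodiff wraps the backend to record gradients, which training requires;
// plain MyBackend is enough for inference.
type MyAutodiffBackend = Autodiff<MyBackend>;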
@@ -52,16 +53,13 @@ impl NeuralNetwork {
         );
     }
     fn infer(&self) {
+        // Creating Backend
         type MyBackend = Wgpu<f32, i32>;
 
         let device = burn::backend::wgpu::WgpuDevice::default();
-
-        // All the training artifacts are saved in this directory
-        let artifact_dir = "/tmp/guide";
-
         // Infer the model
-        infrence::infer::<MyBackend>(
-            artifact_dir,
+        inference::infer::<MyBackend>(
+            MODEL_DIRECTORY,
             device,
             burn::data::dataset::vision::MnistDataset::test()
                 .get(42)
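The call site hands `inference::infer` the same directory constant plus a single MNIST test item. A sketch of the shape that function presumably has, following the Burn guide this code mirrors; the function body sits outside the diff, so its exact signature here is an assumption:

use burn::data::dataset::vision::MnistItem;
use burn::tensor::backend::Backend;

// Assumed signature only: loads the model that train() saved under
// `artifact_dir` (MODEL_DIRECTORY at the call site) and runs one forward pass.
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    let _ = (artifact_dir, device, item);
    todo!("load the saved model record and classify `item`");
}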
@@ -73,7 +71,7 @@ impl NeuralNetwork {
         // Switching based on mode
         match self.mode {
             OperationMode::Training => self.train(),
-            OperationMode::Infrence => self.infer(),
+            OperationMode::Inference => self.infer(),
         }
     }
 }
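Since `OperationMode` has only the `Training` and `Inference` variants, the renamed match stays exhaustive without a catch-all arm, and matching unit variants through `&self` compiles without requiring the enum to be `Copy`.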