generated from OBJNULL/Dockerized-Rust
Compare commits
2 commits
46149b885c
...
ef7046b2c8
Author | SHA1 | Date | |
---|---|---|---|
ef7046b2c8 | |||
b9b07c4b75 |
4 changed files with 16 additions and 18 deletions
|
@ -4,7 +4,7 @@ use std::env;
|
|||
// Enums
|
||||
pub enum OperationMode {
|
||||
Training,
|
||||
Infrence,
|
||||
Inference,
|
||||
}
|
||||
|
||||
// Functions
|
||||
|
@ -15,7 +15,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
|
|||
// Getting operation mode
|
||||
match args[1].as_str() {
|
||||
"training" => Some(OperationMode::Training),
|
||||
"infrence" => Some(OperationMode::Infrence),
|
||||
"inference" => Some(OperationMode::Inference),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,10 +13,10 @@ fn main() {
|
|||
|
||||
// Creating a Neural Network with the Operation Mode
|
||||
match operation_mode {
|
||||
None => panic!("Main: `OperationMode` not defined!"),
|
||||
Some(mode) => {
|
||||
neural = NeuralNetwork::new(mode);
|
||||
}
|
||||
},
|
||||
_ => panic!("Main: `OperationMode` not defined!"),
|
||||
}
|
||||
|
||||
// Starting the network
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Libraries
|
||||
mod data;
|
||||
mod infrence;
|
||||
mod inference;
|
||||
mod model;
|
||||
mod training;
|
||||
use super::config::OperationMode;
|
||||
|
@ -11,6 +11,9 @@ use burn::{
|
|||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
// Constants
|
||||
const MODEL_DIRECTORY: &str = "./out";
|
||||
|
||||
// Structures
|
||||
pub struct NeuralNetwork {
|
||||
mode: OperationMode,
|
||||
|
@ -26,25 +29,23 @@ impl NeuralNetwork {
|
|||
|
||||
// Functions
|
||||
fn train(&self) {
|
||||
// Creating Backend
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
// Create a default Wgpu device
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts will be saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Train the model
|
||||
training::train::<MyAutodiffBackend>(
|
||||
artifact_dir,
|
||||
MODEL_DIRECTORY,
|
||||
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
||||
device.clone(),
|
||||
);
|
||||
|
||||
// Infer the model
|
||||
infrence::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
inference::infer::<MyBackend>(
|
||||
MODEL_DIRECTORY,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
|
@ -52,16 +53,13 @@ impl NeuralNetwork {
|
|||
);
|
||||
}
|
||||
fn infer(&self) {
|
||||
// Creating Backend
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts are saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Infer the model
|
||||
infrence::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
inference::infer::<MyBackend>(
|
||||
MODEL_DIRECTORY,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
|
@ -73,7 +71,7 @@ impl NeuralNetwork {
|
|||
// Switching based on mode
|
||||
match self.mode {
|
||||
OperationMode::Training => self.train(),
|
||||
OperationMode::Infrence => self.infer(),
|
||||
OperationMode::Inference => self.infer(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue