generated from OBJNULL/Dockerized-Rust
Added Infer and Train to neural
This commit is contained in:
parent
e62131ba1f
commit
548112b89e
1 changed files with 52 additions and 3 deletions
|
@ -3,9 +3,14 @@ mod data;
|
||||||
mod infrence;
|
mod infrence;
|
||||||
mod model;
|
mod model;
|
||||||
mod training;
|
mod training;
|
||||||
|
|
||||||
use super::config::OperationMode;
|
use super::config::OperationMode;
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
backend::{Autodiff, WebGpu},
|
||||||
|
data::dataset::Dataset,
|
||||||
|
optim::AdamConfig,
|
||||||
|
};
|
||||||
|
|
||||||
// Structures
|
// Structures
|
||||||
pub struct NeuralNetwork {
|
pub struct NeuralNetwork {
|
||||||
mode: OperationMode,
|
mode: OperationMode,
|
||||||
|
@ -20,11 +25,55 @@ impl NeuralNetwork {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functions
|
// Functions
|
||||||
|
fn train(&self) {
|
||||||
|
type MyBackend = WebGpu<f32, i32>;
|
||||||
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||||
|
|
||||||
|
// Create a default Wgpu device
|
||||||
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
|
|
||||||
|
// All the training artifacts will be saved in this directory
|
||||||
|
let artifact_dir = "/tmp/guide";
|
||||||
|
|
||||||
|
// Train the model
|
||||||
|
training::train::<MyAutodiffBackend>(
|
||||||
|
artifact_dir,
|
||||||
|
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
||||||
|
device.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Infer the model
|
||||||
|
inference::infer::<MyBackend>(
|
||||||
|
artifact_dir,
|
||||||
|
device,
|
||||||
|
burn::data::dataset::vision::MnistDataset::test()
|
||||||
|
.get(42)
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
fn infer(&self) {
|
||||||
|
type MyBackend = WebGpu<f32, i32>;
|
||||||
|
|
||||||
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
|
|
||||||
|
// All the training artifacts are saved in this directory
|
||||||
|
let artifact_dir = "/tmp/guide";
|
||||||
|
|
||||||
|
// Infer the model
|
||||||
|
infrence::infer::<MyBackend>(
|
||||||
|
artifact_dir,
|
||||||
|
device,
|
||||||
|
burn::data::dataset::vision::MnistDataset::test()
|
||||||
|
.get(42)
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn start(&self) {
|
pub fn start(&self) {
|
||||||
// Switching based on mode
|
// Switching based on mode
|
||||||
match self.mode {
|
match self.mode {
|
||||||
OperationMode::Training => {}
|
OperationMode::Training => self.train(),
|
||||||
OperationMode::Infrence => {}
|
OperationMode::Infrence => self.infer(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue