generated from OBJNULL/Dockerized-Rust
Added Infer and Train to neural
This commit is contained in:
parent
e62131ba1f
commit
548112b89e
1 changed files with 52 additions and 3 deletions
|
@ -3,9 +3,14 @@ mod data;
|
|||
mod infrence;
|
||||
mod model;
|
||||
mod training;
|
||||
|
||||
use super::config::OperationMode;
|
||||
|
||||
use burn::{
|
||||
backend::{Autodiff, WebGpu},
|
||||
data::dataset::Dataset,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
// Structures
|
||||
pub struct NeuralNetwork {
|
||||
mode: OperationMode,
|
||||
|
@ -20,11 +25,55 @@ impl NeuralNetwork {
|
|||
}
|
||||
|
||||
// Functions
|
||||
fn train(&self) {
|
||||
type MyBackend = WebGpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
// Create a default Wgpu device
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts will be saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Train the model
|
||||
training::train::<MyAutodiffBackend>(
|
||||
artifact_dir,
|
||||
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
||||
device.clone(),
|
||||
);
|
||||
|
||||
// Infer the model
|
||||
inference::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
fn infer(&self) {
|
||||
type MyBackend = WebGpu<f32, i32>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts are saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Infer the model
|
||||
infrence::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn start(&self) {
|
||||
// Switching based on mode
|
||||
match self.mode {
|
||||
OperationMode::Training => {}
|
||||
OperationMode::Infrence => {}
|
||||
OperationMode::Training => self.train(),
|
||||
OperationMode::Infrence => self.infer(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue