Added Infer and Train to neural

This commit is contained in:
Maddox Werts 2025-03-17 10:19:58 -04:00
parent e62131ba1f
commit 548112b89e

View file

@ -3,9 +3,14 @@ mod data;
mod infrence; mod infrence;
mod model; mod model;
mod training; mod training;
use super::config::OperationMode; use super::config::OperationMode;
use burn::{
backend::{Autodiff, WebGpu},
data::dataset::Dataset,
optim::AdamConfig,
};
// Structures // Structures
pub struct NeuralNetwork { pub struct NeuralNetwork {
mode: OperationMode, mode: OperationMode,
@ -20,11 +25,55 @@ impl NeuralNetwork {
} }
// Functions // Functions
fn train(&self) {
type MyBackend = WebGpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
// Train the model
training::train::<MyAutodiffBackend>(
artifact_dir,
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(),
);
// Infer the model
inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
.unwrap(),
);
}
fn infer(&self) {
type MyBackend = WebGpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";
// Infer the model
infrence::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
.unwrap(),
);
}
pub fn start(&self) { pub fn start(&self) {
// Switching based on mode // Switching based on mode
match self.mode { match self.mode {
OperationMode::Training => {} OperationMode::Training => self.train(),
OperationMode::Infrence => {} OperationMode::Infrence => self.infer(),
} }
} }
} }