generated from OBJNULL/Dockerized-Rust
Code Cleanup and Store in Out folder
This commit is contained in:
parent
46149b885c
commit
b9b07c4b75
1 changed files with 8 additions and 10 deletions
|
@ -11,6 +11,9 @@ use burn::{
|
||||||
optim::AdamConfig,
|
optim::AdamConfig,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Constants
|
||||||
|
const MODEL_DIRECTORY: &str = "./out";
|
||||||
|
|
||||||
// Structures
|
// Structures
|
||||||
pub struct NeuralNetwork {
|
pub struct NeuralNetwork {
|
||||||
mode: OperationMode,
|
mode: OperationMode,
|
||||||
|
@ -26,25 +29,23 @@ impl NeuralNetwork {
|
||||||
|
|
||||||
// Functions
|
// Functions
|
||||||
fn train(&self) {
|
fn train(&self) {
|
||||||
|
// Creating Backend
|
||||||
type MyBackend = Wgpu<f32, i32>;
|
type MyBackend = Wgpu<f32, i32>;
|
||||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||||
|
|
||||||
// Create a default Wgpu device
|
// Create a default Wgpu device
|
||||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
|
|
||||||
// All the training artifacts will be saved in this directory
|
|
||||||
let artifact_dir = "/tmp/guide";
|
|
||||||
|
|
||||||
// Train the model
|
// Train the model
|
||||||
training::train::<MyAutodiffBackend>(
|
training::train::<MyAutodiffBackend>(
|
||||||
artifact_dir,
|
MODEL_DIRECTORY,
|
||||||
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
||||||
device.clone(),
|
device.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Infer the model
|
// Infer the model
|
||||||
infrence::infer::<MyBackend>(
|
infrence::infer::<MyBackend>(
|
||||||
artifact_dir,
|
MODEL_DIRECTORY,
|
||||||
device,
|
device,
|
||||||
burn::data::dataset::vision::MnistDataset::test()
|
burn::data::dataset::vision::MnistDataset::test()
|
||||||
.get(42)
|
.get(42)
|
||||||
|
@ -52,16 +53,13 @@ impl NeuralNetwork {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
fn infer(&self) {
|
fn infer(&self) {
|
||||||
|
// Creating Backend
|
||||||
type MyBackend = Wgpu<f32, i32>;
|
type MyBackend = Wgpu<f32, i32>;
|
||||||
|
|
||||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||||
|
|
||||||
// All the training artifacts are saved in this directory
|
|
||||||
let artifact_dir = "/tmp/guide";
|
|
||||||
|
|
||||||
// Infer the model
|
// Infer the model
|
||||||
infrence::infer::<MyBackend>(
|
infrence::infer::<MyBackend>(
|
||||||
artifact_dir,
|
MODEL_DIRECTORY,
|
||||||
device,
|
device,
|
||||||
burn::data::dataset::vision::MnistDataset::test()
|
burn::data::dataset::vision::MnistDataset::test()
|
||||||
.get(42)
|
.get(42)
|
||||||
|
|
Loading…
Reference in a new issue