generated from OBJNULL/Dockerized-Rust
Code Cleanup and Store in Out folder
This commit is contained in:
parent
46149b885c
commit
b9b07c4b75
1 changed files with 8 additions and 10 deletions
|
@ -11,6 +11,9 @@ use burn::{
|
|||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
// Constants
|
||||
const MODEL_DIRECTORY: &str = "./out";
|
||||
|
||||
// Structures
|
||||
pub struct NeuralNetwork {
|
||||
mode: OperationMode,
|
||||
|
@ -26,25 +29,23 @@ impl NeuralNetwork {
|
|||
|
||||
// Functions
|
||||
fn train(&self) {
|
||||
// Creating Backend
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
// Create a default Wgpu device
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts will be saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Train the model
|
||||
training::train::<MyAutodiffBackend>(
|
||||
artifact_dir,
|
||||
MODEL_DIRECTORY,
|
||||
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
||||
device.clone(),
|
||||
);
|
||||
|
||||
// Infer the model
|
||||
infrence::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
MODEL_DIRECTORY,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
|
@ -52,16 +53,13 @@ impl NeuralNetwork {
|
|||
);
|
||||
}
|
||||
fn infer(&self) {
|
||||
// Creating Backend
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts are saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Infer the model
|
||||
infrence::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
MODEL_DIRECTORY,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
|
|
Loading…
Reference in a new issue