generated from OBJNULL/Dockerized-Rust

Compare commits: bd5c22d9c6 ... 46149b885c (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 46149b885c | |
| | 8608676533 | |
| | 73da28805b | |
| | 48c709414c | |
| | 224ef39603 | |
| | 31c30ed200 | |
| | 8b918113ac | |

7 changed files with 10 additions and 14 deletions

@@ -7,3 +7,6 @@ edition = "2021"
 burn = { version = "0.16.0", features = ["wgpu", "train", "vision"] }
 log = { version = "0.4.26" }
 serde = { version = "1.0.219", features = ["std", "derive"] }
+
+bincode = "=2.0.0-rc.3"
+bincode_derive = "=2.0.0-rc.3"

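The new pinned dependencies (evidently in the crate's Cargo.toml) bring in bincode's 2.0 release candidate together with its derive crate. Below is a minimal sketch of how the pair is typically wired together under the rc.3 API; the `Checkpoint` type and its fields are hypothetical, not taken from this repository:

```rust
use bincode::config;
use bincode_derive::{Decode, Encode};

// Hypothetical payload; the real project defines its own types.
#[derive(Encode, Decode, Debug, PartialEq)]
struct Checkpoint {
    epoch: u32,
    loss: f32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let original = Checkpoint { epoch: 3, loss: 0.42 };

    // bincode 2.0.0-rc.x encodes against an explicit configuration.
    let bytes = bincode::encode_to_vec(&original, config::standard())?;

    // decode_from_slice returns the value plus the number of bytes read.
    let (decoded, _len): (Checkpoint, usize) =
        bincode::decode_from_slice(&bytes, config::standard())?;

    assert_eq!(original, decoded);
    Ok(())
}
```

Pinning with `=` holds both crates at exactly the same pre-release, which keeps the derive macros and the runtime crate in lockstep.
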
@@ -13,7 +13,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
     let args: Vec<String> = env::args().collect();
 
     // Getting operation mode
-    match &args[1] {
+    match args[1].as_str() {
         "training" => Some(OperationMode::Training),
         "infrence" => Some(OperationMode::Infrence),
         _ => None,

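The only change in this hunk, `match &args[1] {` to `match args[1].as_str() {`, is what makes the literal arms type-check: `args[1]` is a `String`, so `&args[1]` is a `&String`, and a `&String` scrutinee cannot be matched against `&str` patterns such as `"training"`. A self-contained sketch of the resulting function; the enum mirrors the variant spellings visible in the diff, while the `main` wrapper is purely illustrative:

```rust
use std::env;

// Variant names follow the spellings used in this repository.
enum OperationMode {
    Training,
    Infrence,
}

fn get_operation_mode() -> Option<OperationMode> {
    let args: Vec<String> = env::args().collect();

    // `as_str()` turns the `String` into a `&str`, letting the literal
    // patterns below compile. As in the diff, `args[1]` still panics when
    // no argument is passed; `args.get(1)` would be the non-panicking form.
    match args[1].as_str() {
        "training" => Some(OperationMode::Training),
        "infrence" => Some(OperationMode::Infrence),
        _ => None,
    }
}

fn main() {
    match get_operation_mode() {
        Some(OperationMode::Training) => println!("training"),
        Some(OperationMode::Infrence) => println!("infrence"),
        None => println!("unknown mode"),
    }
}
```
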
@@ -3,8 +3,6 @@ mod config;
 mod neural;
 use neural::NeuralNetwork;
 
-use std::error::Error;
-
 // Entry-Point
 fn main() {
     // Getting Running Mode First

@@ -6,7 +6,7 @@ mod training;
 use super::config::OperationMode;
 
 use burn::{
-    backend::{Autodiff, WebGpu},
+    backend::{Autodiff, Wgpu},
     data::dataset::Dataset,
     optim::AdamConfig,
 };

@@ -26,7 +26,7 @@ impl NeuralNetwork {
 
     // Functions
     fn train(&self) {
-        type MyBackend = WebGpu<f32, i32>;
+        type MyBackend = Wgpu<f32, i32>;
         type MyAutodiffBackend = Autodiff<MyBackend>;
 
         // Create a default Wgpu device

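Together with the import hunk above and the `fn infer` hunk below, this renames the backend type from `WebGpu` to `Wgpu`, the name burn 0.16 exports for its wgpu-based backend (enabled by the `"wgpu"` feature pinned in Cargo.toml, with `Autodiff` coming via `"train"`). A minimal sketch of the renamed aliases in isolation, assuming burn 0.16 with those same features; the `main` body is illustrative:

```rust
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};

// Same aliases as in the diff: a wgpu backend with f32 floats and i32 ints,
// wrapped in Autodiff so gradients can be tracked during training.
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

fn main() {
    // A default wgpu device, as the comment retained in the diff describes.
    let device = WgpuDevice::default();
    println!(
        "backend: {}, device: {device:?}",
        std::any::type_name::<MyAutodiffBackend>()
    );
}
```
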
@@ -43,7 +43,7 @@ impl NeuralNetwork {
         );
 
         // Infer the model
-        inference::infer::<MyBackend>(
+        infrence::infer::<MyBackend>(
            artifact_dir,
            device,
            burn::data::dataset::vision::MnistDataset::test()

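The call-site fix from `inference::infer` to `infrence::infer` only resolves because the submodule is actually declared under that (misspelled) name; a path has to match the module name exactly as declared. A tiny sketch of that constraint, with placeholder modules and bodies that are not the project's real code:

```rust
mod neural {
    // The submodule is declared as `infrence`, so every path that reaches
    // into it must use that exact spelling; `inference::infer` would be an
    // unresolved-module error.
    pub mod infrence {
        pub fn infer(artifact_dir: &str) {
            println!("inferring with artifacts from {artifact_dir}");
        }
    }

    pub fn run() {
        infrence::infer("/tmp/model");
    }
}

fn main() {
    neural::run();
}
```
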
@@ -52,7 +52,7 @@ impl NeuralNetwork {
         );
     }
     fn infer(&self) {
-        type MyBackend = WebGpu<f32, i32>;
+        type MyBackend = Wgpu<f32, i32>;
 
         let device = burn::backend::wgpu::WgpuDevice::default();
 

@@ -4,11 +4,6 @@ use burn::{
     prelude::*,
 };
 
-use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-    prelude::*,
-};
-
 #[derive(Clone)]
 pub struct MnistBatcher<B: Backend> {
     device: B::Device,

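The removed lines were a second `use burn::{ ... }` block overlapping the one above it. Re-importing the same name through a second `use` statement is a hard error (E0252), and even a harmless overlap such as a repeated `prelude::*` glob is dead weight. Assuming the surviving block already lists the batcher and dataset items, the merged tree would look roughly like this (with burn 0.16 and the features from the Cargo.toml hunk; a sketch, not the file's exact import list):

```rust
// One merged `use` tree instead of two overlapping ones.
use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

// Placeholder `main` so the snippet builds as its own bin target; the real
// file defines MnistBatcher and its Batcher impl instead.
fn main() {
    let _ = std::any::type_name::<MnistItem>();
}
```
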
@@ -1,4 +1,4 @@
-use crate::{data::MnistBatcher, model::Model, training::TrainingConfig};
+use super::{data::MnistBatcher, model::Model, training::TrainingConfig};
 use burn::{
     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     prelude::*,

@@ -1,4 +1,4 @@
-use crate::{
+use super::{
     data::{MnistBatch, MnistBatcher},
     model::{Model, ModelConfig},
 };

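Both `use crate::` to `use super::` hunks point at the same cause: these files now sit inside a parent module (the `neural` module visible in the earlier hunks), so their siblings are no longer reachable from the crate root. A minimal sketch of why `super::` resolves where `crate::` does not, using inline modules with placeholder types; the real project keeps these in separate files under something like a `neural` directory, which is an inference from the diff, not shown in it:

```rust
mod neural {
    pub mod data {
        // Stand-in for the real MnistBatch / MnistBatcher types.
        pub struct MnistBatcher;
    }

    pub mod training {
        // `use crate::data::MnistBatcher;` would fail here: `data` is not a
        // top-level module of the crate. The sibling lives under `neural`,
        // so `super::data` (or `crate::neural::data`) is the path that
        // resolves.
        use super::data::MnistBatcher;

        pub fn build_batcher() -> MnistBatcher {
            MnistBatcher
        }
    }
}

fn main() {
    let _batcher = neural::training::build_batcher();
}
```

`crate::neural::data` would also work; `super::` simply keeps the paths stable if the parent module is ever renamed.
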