generated from OBJNULL/Dockerized-Rust
Compare commits: 7 commits, bd5c22d9c6 ... 46149b885c

Commits:
- 46149b885c
- 8608676533
- 73da28805b
- 48c709414c
- 224ef39603
- 31c30ed200
- 8b918113ac

7 changed files with 10 additions and 14 deletions
@@ -7,3 +7,6 @@ edition = "2021"
 burn = { version = "0.16.0", features = ["wgpu", "train", "vision"] }
 log = { version = "0.4.26" }
 serde = { version = "1.0.219", features = ["std", "derive"] }
+
+bincode = "=2.0.0-rc.3"
+bincode_derive = "=2.0.0-rc.3"
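
This manifest hunk adds logging, serde, and two pre-release crates pinned with an exact `=` requirement, bincode 2.0.0-rc.3 and its derive crate; the exact pin keeps Cargo from moving to a later release candidate or the final 2.0 on its own. The diff does not show how the project uses bincode, so the following is only a minimal sketch of the 2.0 RC encode/decode API with a made-up payload:

```rust
use bincode::config;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Made-up payload; the repository's real types are not visible in this diff.
    let weights: Vec<f32> = vec![0.1, 0.2, 0.3];

    // bincode 2.0-rc API: encode into a byte vector with the standard configuration.
    let bytes = bincode::encode_to_vec(&weights, config::standard())?;

    // decode_from_slice returns the value together with the number of bytes read.
    let (restored, _read): (Vec<f32>, usize) =
        bincode::decode_from_slice(&bytes, config::standard())?;
    assert_eq!(weights, restored);
    Ok(())
}
```

bincode_derive supplies `#[derive(Encode, Decode)]` for the project's own types; listing it explicitly is an alternative to enabling bincode's `derive` feature.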
@@ -13,7 +13,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
     let args: Vec<String> = env::args().collect();

     // Getting operation mode
-    match &args[1] {
+    match args[1].as_str() {
         "training" => Some(OperationMode::Training),
         "infrence" => Some(OperationMode::Infrence),
         _ => None,
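
The only change here is the match scrutinee. String-literal patterns such as `"training"` have type `&str`, so matching them against a `&String` (`match &args[1]`) is a type mismatch and does not compile; calling `.as_str()` first gives the patterns the type they expect. A self-contained sketch of the same shape, reusing the names visible in the diff (`OperationMode`, `Training`, `Infrence`) but otherwise assumed, and using `args.get(1)` instead of the panicking `args[1]` index:

```rust
use std::env;

#[derive(Debug)]
enum OperationMode {
    Training,
    Infrence, // spelling kept as it appears in the diff
}

fn get_operation_mode() -> Option<OperationMode> {
    let args: Vec<String> = env::args().collect();

    // .get(1) avoids the panic that args[1] causes when no argument is passed,
    // and ? turns a missing argument into None.
    match args.get(1)?.as_str() {
        "training" => Some(OperationMode::Training),
        "infrence" => Some(OperationMode::Infrence),
        _ => None,
    }
}

fn main() {
    println!("operation mode: {:?}", get_operation_mode());
}
```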
@@ -3,8 +3,6 @@ mod config;
 mod neural;
 use neural::NeuralNetwork;

-use std::error::Error;
-
 // Entry-Point
 fn main() {
     // Getting Running Mode First
@@ -6,7 +6,7 @@ mod training;
 use super::config::OperationMode;

 use burn::{
-    backend::{Autodiff, WebGpu},
+    backend::{Autodiff, Wgpu},
     data::dataset::Dataset,
     optim::AdamConfig,
 };
@@ -26,7 +26,7 @@ impl NeuralNetwork {

     // Functions
     fn train(&self) {
-        type MyBackend = WebGpu<f32, i32>;
+        type MyBackend = Wgpu<f32, i32>;
         type MyAutodiffBackend = Autodiff<MyBackend>;

         // Create a default Wgpu device
@@ -43,7 +43,7 @@ impl NeuralNetwork {
         );

         // Infer the model
-        inference::infer::<MyBackend>(
+        infrence::infer::<MyBackend>(
             artifact_dir,
             device,
             burn::data::dataset::vision::MnistDataset::test()
@@ -52,7 +52,7 @@ impl NeuralNetwork {
         );
     }
     fn infer(&self) {
-        type MyBackend = WebGpu<f32, i32>;
+        type MyBackend = Wgpu<f32, i32>;

         let device = burn::backend::wgpu::WgpuDevice::default();
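
The three hunks above consistently swap the backend name from `WebGpu` to `Wgpu` in the import and in both method-local type aliases, while the training path keeps wrapping the backend in `Autodiff` for gradient support. A minimal, compilable sketch that wires the same aliases together (the `backend_names` helper is only for illustration and not part of the repository):

```rust
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};

// Same aliases as in the train()/infer() hunks: the plain Wgpu backend for
// inference, wrapped in Autodiff when gradients are needed for training.
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

// Illustrative helper (not in the repository): report the concrete backend types.
fn backend_names() -> (&'static str, &'static str) {
    (
        std::any::type_name::<MyBackend>(),
        std::any::type_name::<MyAutodiffBackend>(),
    )
}

fn main() {
    // Default wgpu device, exactly as constructed in the infer() hunk.
    let device = WgpuDevice::default();
    let (inference, training) = backend_names();
    println!("device: {device:?}\ninference backend: {inference}\ntraining backend: {training}");
}
```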
@@ -4,11 +4,6 @@ use burn::{
     prelude::*,
 };

-use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-    prelude::*,
-};
-
 #[derive(Clone)]
 pub struct MnistBatcher<B: Backend> {
     device: B::Device,
@@ -1,4 +1,4 @@
-use crate::{data::MnistBatcher, model::Model, training::TrainingConfig};
+use super::{data::MnistBatcher, model::Model, training::TrainingConfig};
 use burn::{
     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     prelude::*,
@@ -1,4 +1,4 @@
-use crate::{
+use super::{
     data::{MnistBatch, MnistBatcher},
     model::{Model, ModelConfig},
 };
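
The last two hunks change only the import path from `crate::…` to `super::…`. That is the expected adjustment when modules such as the inference and training files sit inside a parent module (the `mod neural;` declared in main.rs) rather than at the crate root: `super::` names the parent module directly, so sibling modules resolve without spelling out the full crate path. A self-contained sketch of that shape; the stand-in structs and the `train` function are placeholders, only the path layout mirrors the diff:

```rust
// Parent module, analogous to `mod neural;` in main.rs.
mod neural {
    pub mod data {
        pub struct MnistBatcher; // placeholder for the real batcher
    }
    pub mod model {
        pub struct Model; // placeholder for the real model
    }
    pub mod training {
        // `super::` resolves to the parent module (`crate::neural`), so the
        // sibling modules are reachable without an absolute `crate::` path.
        use super::{data::MnistBatcher, model::Model};

        pub fn train() -> (MnistBatcher, Model) {
            (MnistBatcher, Model)
        }
    }
}

fn main() {
    let _ = neural::training::train();
}
```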