Compare commits

..

7 commits

7 changed files with 10 additions and 14 deletions

View file

@ -7,3 +7,6 @@ edition = "2021"
burn = { version = "0.16.0", features = ["wgpu", "train", "vision"] }
log = { version = "0.4.26" }
serde = { version = "1.0.219", features = ["std", "derive"] }
bincode = "=2.0.0-rc.3"
bincode_derive = "=2.0.0-rc.3"

View file

@ -13,7 +13,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
let args: Vec<String> = env::args().collect();
// Getting operation mode
match &args[1] {
match args[1].as_str() {
"training" => Some(OperationMode::Training),
"infrence" => Some(OperationMode::Infrence),
_ => None,

View file

@ -3,8 +3,6 @@ mod config;
mod neural;
use neural::NeuralNetwork;
use std::error::Error;
// Entry-Point
fn main() {
// Getting Running Mode First

View file

@ -6,7 +6,7 @@ mod training;
use super::config::OperationMode;
use burn::{
backend::{Autodiff, WebGpu},
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
@ -26,7 +26,7 @@ impl NeuralNetwork {
// Functions
fn train(&self) {
type MyBackend = WebGpu<f32, i32>;
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device
@ -43,7 +43,7 @@ impl NeuralNetwork {
);
// Infer the model
inference::infer::<MyBackend>(
infrence::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
@ -52,7 +52,7 @@ impl NeuralNetwork {
);
}
fn infer(&self) {
type MyBackend = WebGpu<f32, i32>;
type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default();

View file

@ -4,11 +4,6 @@ use burn::{
prelude::*,
};
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
};
#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
device: B::Device,

View file

@ -1,4 +1,4 @@
use crate::{data::MnistBatcher, model::Model, training::TrainingConfig};
use super::{data::MnistBatcher, model::Model, training::TrainingConfig};
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,

View file

@ -1,4 +1,4 @@
use crate::{
use super::{
data::{MnistBatch, MnistBatcher},
model::{Model, ModelConfig},
};