Compare commits

..

7 commits

7 changed files with 10 additions and 14 deletions

View file

@ -7,3 +7,6 @@ edition = "2021"
burn = { version = "0.16.0", features = ["wgpu", "train", "vision"] } burn = { version = "0.16.0", features = ["wgpu", "train", "vision"] }
log = { version = "0.4.26" } log = { version = "0.4.26" }
serde = { version = "1.0.219", features = ["std", "derive"] } serde = { version = "1.0.219", features = ["std", "derive"] }
bincode = "=2.0.0-rc.3"
bincode_derive = "=2.0.0-rc.3"

View file

@ -13,7 +13,7 @@ pub fn get_operation_mode() -> Option<OperationMode> {
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();
// Getting operation mode // Getting operation mode
match &args[1] { match args[1].as_str() {
"training" => Some(OperationMode::Training), "training" => Some(OperationMode::Training),
"infrence" => Some(OperationMode::Infrence), "infrence" => Some(OperationMode::Infrence),
_ => None, _ => None,

View file

@ -3,8 +3,6 @@ mod config;
mod neural; mod neural;
use neural::NeuralNetwork; use neural::NeuralNetwork;
use std::error::Error;
// Entry-Point // Entry-Point
fn main() { fn main() {
// Getting Running Mode First // Getting Running Mode First

View file

@ -6,7 +6,7 @@ mod training;
use super::config::OperationMode; use super::config::OperationMode;
use burn::{ use burn::{
backend::{Autodiff, WebGpu}, backend::{Autodiff, Wgpu},
data::dataset::Dataset, data::dataset::Dataset,
optim::AdamConfig, optim::AdamConfig,
}; };
@ -26,7 +26,7 @@ impl NeuralNetwork {
// Functions // Functions
fn train(&self) { fn train(&self) {
type MyBackend = WebGpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device // Create a default Wgpu device
@ -43,7 +43,7 @@ impl NeuralNetwork {
); );
// Infer the model // Infer the model
inference::infer::<MyBackend>( infrence::infer::<MyBackend>(
artifact_dir, artifact_dir,
device, device,
burn::data::dataset::vision::MnistDataset::test() burn::data::dataset::vision::MnistDataset::test()
@ -52,7 +52,7 @@ impl NeuralNetwork {
); );
} }
fn infer(&self) { fn infer(&self) {
type MyBackend = WebGpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();

View file

@ -4,11 +4,6 @@ use burn::{
prelude::*, prelude::*,
}; };
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
};
#[derive(Clone)] #[derive(Clone)]
pub struct MnistBatcher<B: Backend> { pub struct MnistBatcher<B: Backend> {
device: B::Device, device: B::Device,

View file

@ -1,4 +1,4 @@
use crate::{data::MnistBatcher, model::Model, training::TrainingConfig}; use super::{data::MnistBatcher, model::Model, training::TrainingConfig};
use burn::{ use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*, prelude::*,

View file

@ -1,4 +1,4 @@
use crate::{ use super::{
data::{MnistBatch, MnistBatcher}, data::{MnistBatch, MnistBatcher},
model::{Model, ModelConfig}, model::{Model, ModelConfig},
}; };