diff --git a/project/src/neural.rs b/project/src/neural.rs index 20a8513..a834664 100644 --- a/project/src/neural.rs +++ b/project/src/neural.rs @@ -6,7 +6,7 @@ mod training; use super::config::OperationMode; use burn::{ - backend::{Autodiff, WebGpu}, + backend::{Autodiff, Wgpu}, data::dataset::Dataset, optim::AdamConfig, }; @@ -26,7 +26,7 @@ impl NeuralNetwork { // Functions fn train(&self) { - type MyBackend = WebGpu; + type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device @@ -52,7 +52,7 @@ impl NeuralNetwork { ); } fn infer(&self) { - type MyBackend = WebGpu; + type MyBackend = Wgpu; let device = burn::backend::wgpu::WgpuDevice::default();