Updated WebGPU to WGPU

This commit is contained in:
Maddox Werts 2025-03-17 12:14:18 -04:00
parent 48c709414c
commit 73da28805b

View file

@ -6,7 +6,7 @@ mod training;
use super::config::OperationMode; use super::config::OperationMode;
use burn::{ use burn::{
backend::{Autodiff, WebGpu}, backend::{Autodiff, Wgpu},
data::dataset::Dataset, data::dataset::Dataset,
optim::AdamConfig, optim::AdamConfig,
}; };
@ -26,7 +26,7 @@ impl NeuralNetwork {
// Functions // Functions
fn train(&self) { fn train(&self) {
type MyBackend = WebGpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>; type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device // Create a default Wgpu device
@ -52,7 +52,7 @@ impl NeuralNetwork {
); );
} }
fn infer(&self) { fn infer(&self) {
type MyBackend = WebGpu<f32, i32>; type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default(); let device = burn::backend::wgpu::WgpuDevice::default();