// Generated from OBJNULL/Dockerized-Rust (57 lines, 1.7 KiB, Rust)
// Libraries
|
|
use burn::{
|
|
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
|
prelude::*,
|
|
};
|
|
|
|
use burn::{
|
|
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
|
prelude::*,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct MnistBatcher<B: Backend> {
|
|
device: B::Device,
|
|
}
|
|
|
|
impl<B: Backend> MnistBatcher<B> {
|
|
pub fn new(device: B::Device) -> Self {
|
|
Self { device }
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct MnistBatch<B: Backend> {
|
|
pub images: Tensor<B, 3>,
|
|
pub targets: Tensor<B, 1, Int>,
|
|
}
|
|
|
|
impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
|
|
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
|
|
let images = items
|
|
.iter()
|
|
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
|
|
.map(|data| Tensor::<B, 2>::from_data(data, &self.device))
|
|
.map(|tensor| tensor.reshape([1, 28, 28]))
|
|
// normalize: make between [0,1] and make the mean = 0 and std = 1
|
|
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
|
|
// https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
|
|
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
|
|
.collect();
|
|
|
|
let targets = items
|
|
.iter()
|
|
.map(|item| {
|
|
Tensor::<B, 1, Int>::from_data(
|
|
[(item.label as i64).elem::<B::IntElem>()],
|
|
&self.device,
|
|
)
|
|
})
|
|
.collect();
|
|
|
|
let images = Tensor::cat(images, 0);
|
|
let targets = Tensor::cat(targets, 0);
|
|
|
|
MnistBatch { images, targets }
|
|
}
|
|
}
|