generated from OBJNULL/Dockerized-Rust
Borrowed code from Burn Lib Guide
parent 86d432187b
commit 7f6a9c41c9
4 changed files with 256 additions and 0 deletions
52 project/src/data.rs Normal file
@@ -0,0 +1,52 @@
// Libraries
use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            // Normalize: scale to [0, 1], then shift to mean 0 and std 1.
            // The values mean=0.1307, std=0.3081 come from the PyTorch MNIST example:
            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
            .collect();

        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data(
                    [(item.label as i64).elem::<B::IntElem>()],
                    &self.device,
                )
            })
            .collect();

        let images = Tensor::cat(images, 0);
        let targets = Tensor::cat(targets, 0);

        MnistBatch { images, targets }
    }
}
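Not part of the commit, but a quick sanity check of the batcher can be run along these lines (a sketch; the wgpu backend and Burn's bundled MNIST dataset are assumptions carried over from the guide):

mod data; // the file shown above

use burn::backend::Wgpu;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataset::{vision::MnistDataset, Dataset};
use data::MnistBatcher;

fn main() {
    let device = Default::default();
    let batcher = MnistBatcher::<Wgpu>::new(device);

    // Batch the first four training images and check the stacked shape.
    let dataset = MnistDataset::train();
    let items: Vec<_> = (0..4).filter_map(|i| dataset.get(i)).collect();
    let batch = batcher.batch(items);
    println!("{:?}", batch.images.dims()); // expected: [4, 28, 28]
}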
24 project/src/infrence.rs Normal file
@@ -0,0 +1,24 @@
use crate::{data::MnistBatcher, model::Model, training::TrainingConfig};
use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
    record::{CompactRecorder, Recorder},
};

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model; run train first");
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Trained model should exist; run train first");

    let model: Model<B> = config.model.init(&device).load_record(record);

    let label = item.label;
    let batcher = MnistBatcher::new(device);
    let batch = batcher.batch(vec![item]);
    let output = model.forward(batch.images);
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

    println!("Predicted {} Expected {}", predicted, label);
}
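For readability, the shape narrowing in the prediction line works out as follows (annotation only, not code from the commit):

// output:               Tensor<B, 2>      -- [1, num_classes] for the single-item batch
// output.argmax(1):     Tensor<B, 2, Int> -- [1, 1], index of the highest-scoring class
// .flatten::<1>(0, 1):  Tensor<B, 1, Int> -- [1]
// .into_scalar():       B::IntElem        -- the predicted digit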
68 project/src/model.rs Normal file
@@ -0,0 +1,68 @@
use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

impl<B: Backend> Model<B> {
    /// # Shapes
    ///   - Images [batch_size, height, width]
    ///   - Output [batch_size, class_prob]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();

        // Add a channel dimension: [batch_size, 1, height, width].
        let x = images.reshape([batch_size, 1, height, width]);

        let x = self.conv1.forward(x); // [batch_size, 8, _, _]
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x); // [batch_size, 16, _, _]
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]
        let x = x.reshape([batch_size, 16 * 8 * 8]);
        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        self.linear2.forward(x) // [batch_size, num_classes]
    }
}
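A minimal shape check of the model on random input (a sketch, not in the commit; the wgpu backend and the guide's 10-class/512-hidden sizes are assumptions):

mod model; // the file shown above

use burn::backend::Wgpu;
use burn::tensor::{Distribution, Tensor};
use model::ModelConfig;

fn main() {
    let device = Default::default();
    // 10 classes, hidden size 512 (the guide's defaults, assumed here).
    let model = ModelConfig::new(10, 512).init::<Wgpu>(&device);

    // A fake batch of eight 28x28 "images" drawn uniformly from [0, 1).
    let images = Tensor::<Wgpu, 3>::random([8, 28, 28], Distribution::Default, &device);
    let output = model.forward(images);
    println!("{:?}", output.dims()); // expected: [8, 10]
}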
112 project/src/training.rs Normal file
@@ -0,0 +1,112 @@
use crate::{
    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};
use burn::{
    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    nn::loss::CrossEntropyLossConfig,
    optim::AdamConfig,
    prelude::*,
    record::CompactRecorder,
    tensor::backend::AutodiffBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
    },
};

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}

#[derive(Config)]
pub struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 1.0e-4)]
    pub learning_rate: f64,
}

fn create_artifact_dir(artifact_dir: &str) {
    // Remove any existing artifacts first so the learner summary is accurate.
    std::fs::remove_dir_all(artifact_dir).ok();
    std::fs::create_dir_all(artifact_dir).ok();
}

pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
    create_artifact_dir(artifact_dir);
    config
        .save(format!("{artifact_dir}/config.json"))
        .expect("Config should be saved successfully");

    B::seed(config.seed);

    let batcher_train = MnistBatcher::<B>::new(device.clone());
    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(
            config.model.init::<B>(&device),
            config.optimizer.init(),
            config.learning_rate,
        );

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    model_trained
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}
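The commit adds no binary entry point; the Burn guide wires these modules together in a main.rs roughly as follows (a sketch; the Wgpu backend, the 10-class/512-hidden model sizes, and the "/tmp/guide" artifact path are the guide's choices, not confirmed by this commit):

mod data;
mod infrence;
mod model;
mod training;

use burn::backend::{Autodiff, Wgpu};
use burn::data::dataset::{vision::MnistDataset, Dataset};
use burn::optim::AdamConfig;
use model::ModelConfig;
use training::TrainingConfig;

fn main() {
    // Plain Wgpu for inference; Autodiff<Wgpu> adds backprop support for training.
    type MyBackend = Wgpu;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let artifact_dir = "/tmp/guide";

    // Train, saving config.json and the model record under artifact_dir...
    training::train::<MyAutodiffBackend>(
        artifact_dir,
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
    // ...then predict a single held-out test item from the saved artifacts.
    infrence::infer::<MyBackend>(
        artifact_dir,
        device,
        MnistDataset::test().get(42).unwrap(),
    );
}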