generated from OBJNULL/Dockerized-Rust
Created network manager
This commit is contained in:
parent
d68f59aedd
commit
891a608f6f
2 changed files with 30 additions and 112 deletions
30
project/src/neural.rs
Normal file
30
project/src/neural.rs
Normal file
|
@ -0,0 +1,30 @@
|
|||
// Libraries
|
||||
mod data;
|
||||
mod infrence;
|
||||
mod model;
|
||||
mod training;
|
||||
|
||||
use super::config::OperationMode;
|
||||
|
||||
// Structures
|
||||
pub struct NeuralNetwork {
|
||||
mode: OperationMode,
|
||||
}
|
||||
|
||||
// Implementaions
|
||||
impl NeuralNetwork {
|
||||
// Constructors
|
||||
pub fn new(mode: OperationMode) -> Self {
|
||||
// Return Result
|
||||
Self { mode }
|
||||
}
|
||||
|
||||
// Functions
|
||||
pub fn start(&self) {
|
||||
// Switching based on mode
|
||||
match self.mode {
|
||||
OperationMode::Training => {}
|
||||
OperationMode::Infrence => {}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
use crate::{
|
||||
data::{MnistBatch, MnistBatcher},
|
||||
model::{Model, ModelConfig},
|
||||
};
|
||||
use burn::{
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
nn::loss::CrossEntropyLossConfig,
|
||||
optim::AdamConfig,
|
||||
prelude::*,
|
||||
record::CompactRecorder,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn forward_classification(
|
||||
&self,
|
||||
images: Tensor<B, 3>,
|
||||
targets: Tensor<B, 1, Int>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let output = self.forward(images);
|
||||
let loss = CrossEntropyLossConfig::new()
|
||||
.init(&output.device())
|
||||
.forward(output.clone(), targets.clone());
|
||||
|
||||
ClassificationOutput::new(loss, output, targets)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
|
||||
fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
|
||||
let item = self.forward_classification(batch.images, batch.targets);
|
||||
|
||||
TrainOutput::new(self, item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
|
||||
fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
|
||||
self.forward_classification(batch.images, batch.targets)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct TrainingConfig {
|
||||
pub model: ModelConfig,
|
||||
pub optimizer: AdamConfig,
|
||||
#[config(default = 10)]
|
||||
pub num_epochs: usize,
|
||||
#[config(default = 64)]
|
||||
pub batch_size: usize,
|
||||
#[config(default = 4)]
|
||||
pub num_workers: usize,
|
||||
#[config(default = 42)]
|
||||
pub seed: u64,
|
||||
#[config(default = 1.0e-4)]
|
||||
pub learning_rate: f64,
|
||||
}
|
||||
|
||||
fn create_artifact_dir(artifact_dir: &str) {
|
||||
// Remove existing artifacts before to get an accurate learner summary
|
||||
std::fs::remove_dir_all(artifact_dir).ok();
|
||||
std::fs::create_dir_all(artifact_dir).ok();
|
||||
}
|
||||
|
||||
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
|
||||
create_artifact_dir(artifact_dir);
|
||||
config
|
||||
.save(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should be saved successfully");
|
||||
|
||||
B::seed(config.seed);
|
||||
|
||||
let batcher_train = MnistBatcher::<B>::new(device.clone());
|
||||
let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
|
||||
|
||||
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||
.batch_size(config.batch_size)
|
||||
.shuffle(config.seed)
|
||||
.num_workers(config.num_workers)
|
||||
.build(MnistDataset::train());
|
||||
|
||||
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
|
||||
.batch_size(config.batch_size)
|
||||
.shuffle(config.seed)
|
||||
.num_workers(config.num_workers)
|
||||
.build(MnistDataset::test());
|
||||
|
||||
let learner = LearnerBuilder::new(artifact_dir)
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(config.num_epochs)
|
||||
.summary()
|
||||
.build(
|
||||
config.model.init::<B>(&device),
|
||||
config.optimizer.init(),
|
||||
config.learning_rate,
|
||||
);
|
||||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
model_trained
|
||||
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||
.expect("Trained model should be saved successfully");
|
||||
}
|
Loading…
Reference in a new issue