// Generated from OBJNULL/Dockerized-Rust

use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

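// Note: `#[derive(Module)]` is what lets Burn track the learnable parameters
// held by the struct above (for optimization, checkpointing, and device moves).
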
#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

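// Note: `#[derive(Config)]` generates a `ModelConfig::new(num_classes, hidden_size)`
// constructor for the non-defaulted fields, plus a `with_dropout(..)` builder
// method for overriding the 0.5 default.
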
impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            // The adaptive pool always emits an 8x8 feature map, so the
            // 16 * 8 * 8 input size of `linear1` holds for any image size.
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

impl<B: Backend> Model<B> {
    /// # Shapes
    /// - Images [batch_size, height, width]
    /// - Output [batch_size, class_prob]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();

        // Create a channel dimension at the second position.
        let x = images.reshape([batch_size, 1, height, width]);

        let x = self.conv1.forward(x); // [batch_size, 8, _, _]
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x); // [batch_size, 16, _, _]
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]
        let x = x.reshape([batch_size, 16 * 8 * 8]);
        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        self.linear2.forward(x) // [batch_size, num_classes]
    }
}
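
// A minimal usage sketch (added for illustration; not part of the original
// file). It assumes `burn` is compiled with the `ndarray` feature so that
// `burn::backend::NdArray` is available; any other backend works the same way.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_outputs_one_score_per_class() {
        type B = NdArray;
        let device = Default::default();

        // 10 classes and 512 hidden units; dropout keeps its 0.5 default.
        let model: Model<B> = ModelConfig::new(10, 512).init(&device);

        // A batch of two 28x28 single-channel images (all zeros).
        let images = Tensor::<B, 3>::zeros([2, 28, 28], &device);
        let output = model.forward(images);

        assert_eq!(output.dims(), [2, 10]);
    }
}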