use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

impl<B: Backend> Model<B> {
    /// # Shapes
    ///   - Images [batch_size, height, width]
    ///   - Output [batch_size, num_classes]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();

        // Create a channel dimension at the second position.
        let x = images.reshape([batch_size, 1, height, width]);

        let x = self.conv1.forward(x); // [batch_size, 8, _, _]
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x); // [batch_size, 16, _, _]
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]
        let x = x.reshape([batch_size, 16 * 8 * 8]);
        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        self.linear2.forward(x) // [batch_size, num_classes]
    }
}
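
// A minimal usage sketch, not part of the original listing: it instantiates the
// model and checks the forward-pass output shape. It assumes the `ndarray`
// backend feature is enabled; any type implementing `Backend` works the same way.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_outputs_one_logit_per_class() {
        type B = NdArray;
        let device = Default::default();

        // 10 classes, 512 hidden units; dropout keeps its 0.5 default.
        let model: Model<B> = ModelConfig::new(10, 512).init(&device);

        // A batch of two dummy 28x28 (MNIST-sized) images.
        let images = Tensor::<B, 3>::zeros([2, 28, 28], &device);
        let output = model.forward(images);

        // One logit per class for each image in the batch.
        assert_eq!(output.dims(), [2, 10]);
    }
}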