How to set up a different model
In this tutorial, we show how to quickly edit the template to try a new architecture. As an example, we create a simple CNN architecture to classify MNIST images.
Define the new architecture
In src/quicksetup_ai/models/components, we create a new file called simple_cnn.py. For simplicity, we adapt the code from the PyTorch documentation to create a CNN.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """Two convolution + pooling stages followed by one linear classifier."""

    def __init__(
        self,
        input_chans: int = 1,
        conv1_out_chans: int = 6,
        conv1_kernel_size: int = 5,
        conv2_out_chans: int = 16,
        conv2_kernel_size: int = 5,
        max_pool_size: int = 2,
        lin_input_size: int = 16 * 4 * 4,  # conv2_out_chans * 4 * 4 for 28x28 inputs
        output_size: int = 10,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(input_chans, conv1_out_chans, conv1_kernel_size)
        # The same pooling layer is reused after both convolutions.
        self.pool = nn.MaxPool2d(max_pool_size, max_pool_size)
        self.conv2 = nn.Conv2d(conv1_out_chans, conv2_out_chans, conv2_kernel_size)
        self.fc = nn.Linear(lin_input_size, output_size)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc(x)
        return x
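The default lin_input_size of 16 * 4 * 4 depends on the input image size and on the kernel and pooling sizes. The following is a minimal sketch of how that number can be derived, assuming 28x28 MNIST images, no padding, stride 1 for the convolutions, and a pooling stride equal to the pooling size. The helper names conv_out and flattened_size are hypothetical and not part of the template.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_size(
    image_size: int = 28,
    conv1_kernel_size: int = 5,
    conv2_out_chans: int = 16,
    conv2_kernel_size: int = 5,
    max_pool_size: int = 2,
) -> int:
    """Number of features reaching the linear layer of SimpleCNN."""
    size = conv_out(image_size, conv1_kernel_size)               # 28 -> 24
    size = conv_out(size, max_pool_size, stride=max_pool_size)   # 24 -> 12
    size = conv_out(size, conv2_kernel_size)                     # 12 -> 8
    size = conv_out(size, max_pool_size, stride=max_pool_size)   # 8 -> 4
    return conv2_out_chans * size * size


print(flattened_size())  # 256, i.e. 16 * 4 * 4
```

If you change the kernel sizes, the pooling size, or the input resolution, lin_input_size has to be updated to match, otherwise the linear layer will raise a shape mismatch error.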
Create the configuration file
In configs/model/, we create a new file called mnist_cnn.yaml and copy the content of mnist.yaml into it. Then, we edit the net section by setting _target_ to the new model and providing the parameters we want for training. Here is the content of the new file.
_target_: quicksetup_ai.models.mnist_module.MNISTLitModule
lr: 0.001
weight_decay: 0.0005

net:
  _target_: quicksetup_ai.models.components.simple_cnn.SimpleCNN
  input_chans: 1
  conv1_out_chans: 6
  conv1_kernel_size: 5
  conv2_out_chans: 16
  conv2_kernel_size: 5
  max_pool_size: 2
  lin_input_size: 256
  output_size: 10
For the sake of simplicity, we keep using MNISTLitModule, since we do not need to alter the training process.
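To build some intuition for what the _target_ keys do, here is a simplified, standard-library-only sketch of recursive instantiation. It is an illustration of the idea, not Hydra's actual implementation, which handles many more cases. The function name instantiate_sketch is hypothetical, and types.SimpleNamespace stands in for MNISTLitModule and SimpleCNN so that the example runs without the template installed.

```python
import importlib
from typing import Any


def instantiate_sketch(config: dict) -> Any:
    """Simplified stand-in for recursive `_target_` instantiation."""
    kwargs = {}
    for key, value in config.items():
        if key == "_target_":
            continue
        # Nested sections that carry their own _target_ are built first.
        if isinstance(value, dict) and "_target_" in value:
            value = instantiate_sketch(value)
        kwargs[key] = value
    # Split "package.module.ClassName" into module path and attribute name.
    module_path, _, attr_name = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Stand-in config mirroring the shape of mnist_cnn.yaml.
config = {
    "_target_": "types.SimpleNamespace",
    "lr": 0.001,
    "weight_decay": 0.0005,
    "net": {"_target_": "types.SimpleNamespace", "output_size": 10},
}
module = instantiate_sketch(config)
print(module.lr, module.net.output_size)
```

The point of the sketch is that the inner net section is turned into an object first and then passed as the net argument of the outer class, which is why swapping the architecture only requires changing the net section.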
Edit the main train/test configurations
The final step is to configure configs/train.yaml and configs/test.yaml.
First, we edit the training configuration: under defaults, we set model to mnist_cnn.yaml so that training uses the new model. Once training has completed, we edit the test configuration in the same way, setting model to mnist_cnn.yaml under defaults. Then, we set ckpt_path to the checkpoint of the model we want to test.
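As a sketch, the changed entries might look as follows. Only the model entry and ckpt_path come from the steps described here; all other entries in both files are assumed to stay as they are in the template, and the checkpoint path is a hypothetical placeholder that you need to replace with your own.

```yaml
# configs/train.yaml (only the changed entry is shown)
defaults:
  - model: mnist_cnn.yaml
```

```yaml
# configs/test.yaml (only the changed entries are shown)
defaults:
  - model: mnist_cnn.yaml

ckpt_path: /path/to/your/checkpoint.ckpt  # hypothetical placeholder
```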
Congratulations! You can now use custom model architectures in the template.