How to set up a different dataset

In this tutorial, we show how to quickly adapt the template to use another dataset, using CIFAR10 as the example.

Create a new DataModule

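Under src/quicksetup-ai/datamodules/, we create a new file called cifar_10_datamodule.py and copy in the content of mnist_datamodule.py. We then edit the dataset-specific parts: prepare_data and setup use the CIFAR10 dataset instead of MNIST, and the default splits and transforms are adapted (Normalize now takes three values, one per RGB channel). CIFAR10 has 50,000 training and 10,000 test images; setup concatenates them, so the default train_val_test_split of (40_000, 5_000, 15_000) adds up to all 60,000 images.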

from typing import Optional, Tuple

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms

class CIFAR10DataModule(LightningDataModule):
    """
    Example of LightningDataModule for CIFAR10 dataset.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """
    
    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (40_000, 5_000, 15_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows accessing init params through the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)
        
        # data transformations
        self.transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self) -> int:
        return 10

    def prepare_data(self):
        """Download data if needed. This method is called only from a single GPU.
        Do not use it to assign state (self.x = y)."""
        CIFAR10(self.hparams.data_dir, train=True, download=True)
        CIFAR10(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
        This method is called by lightning twice for `trainer.fit()` and `trainer.test()`, so be careful if you do a random split!
        The `stage` can be used to differentiate whether it's called before `trainer.fit()` or `trainer.test()`."""

        # load datasets only if they're not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = CIFAR10(self.hparams.data_dir, train=True, transform=self.transforms)
            testset = CIFAR10(self.hparams.data_dir, train=False, transform=self.transforms)
            dataset = ConcatDataset(datasets=[trainset, testset])
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )
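Because setup() splits the concatenated dataset at random, the fixed seed matters: a separate process, such as a later test run or another GPU in distributed training, builds its own DataModule and must get exactly the same subsets, otherwise test images could end up in the training set. The sketch below checks this with a stand-in TensorDataset of 60,000 dummy samples instead of downloading CIFAR10; the split helper is hypothetical and only mirrors the random_split call in setup().

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for ConcatDataset([trainset, testset]): 60,000 dummy samples,
# as many as CIFAR10's 50,000 training + 10,000 test images.
dataset = TensorDataset(torch.arange(60_000))
lengths = (40_000, 5_000, 15_000)  # must sum to len(dataset)


def split(seed: int = 42):
    """Hypothetical helper mirroring the random_split call in setup()."""
    return random_split(
        dataset=dataset,
        lengths=lengths,
        generator=torch.Generator().manual_seed(seed),
    )


# Two independent calls, e.g. one in a training run and one in a test run.
first = split()
second = split()

sizes = [len(subset) for subset in first]
identical = all(a.indices == b.indices for a, b in zip(first, second))
print(sizes, identical)  # [40000, 5000, 15000] True
```

Changing the seed changes which images land in each subset, so keep it fixed once you have trained models you intend to evaluate.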

Create the configuration file

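In configs/datamodule/, we create a new file called cifar_10.yaml. The _target_ key tells Hydra which DataModule class to instantiate, and the remaining keys are passed to its constructor as keyword arguments, so we set them to the values we want for the experiment. Here is the content of cifar_10.yaml: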

_target_: quicksetup-ai.datamodules.cifar_10_datamodule.CIFAR10DataModule

data_dir: ${data_dir} # data_dir is specified in config.yaml
batch_size: 64
train_val_test_split: [40_000, 5_000, 15_000]
num_workers: 0
pin_memory: False
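Hydra turns this file into an object by importing the class named in _target_ and calling it with the remaining keys as keyword arguments. The sketch below imitates that step with importlib so it runs without Hydra or the project installed. instantiate_sketch is a hypothetical, heavily simplified stand-in for hydra.utils.instantiate (no interpolation, recursive instantiation, or type conversion), demonstrated on standard-library classes rather than CIFAR10DataModule.

```python
import importlib
from typing import Any, Dict


def instantiate_sketch(cfg: Dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate:
    import the class named by `_target_` and call it with the remaining
    keys as keyword arguments (no interpolation, recursion or conversion)."""
    kwargs = dict(cfg)  # copy so the caller's config is not modified
    module_path, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


# Standard-library targets stand in for CIFAR10DataModule here.
half = instantiate_sketch({"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2})
delay = instantiate_sketch({"_target_": "datetime.timedelta", "minutes": 5})
print(half, delay)  # 1/2 0:05:00
```

Applied to cifar_10.yaml, the real mechanism amounts to calling CIFAR10DataModule(data_dir=..., batch_size=64, train_val_test_split=[40_000, 5_000, 15_000], num_workers=0, pin_memory=False), with data_dir resolved from the ${data_dir} interpolation.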

Edit the main train/test configurations

The final step is to configure configs/train.yaml and configs/test.yaml.

First, we edit the training configuration: under defaults, we set datamodule to cifar_10.yaml so that training uses the new DataModule.

We also need a new model configuration file, configs/model/cifar_10.yaml, in which we adjust the model parameters to the new dataset (for example, the input dimensions, since CIFAR10 images are 32×32 RGB rather than 28×28 grayscale as in MNIST). In train.yaml, we then set model under defaults to cifar_10.yaml.
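As a worked example of such an adjustment, assume (an assumption about your model, not something the template guarantees) that the network is fully connected and flattens each image, so its input size equals channels × height × width. The hypothetical flat_input_size helper below computes it for both datasets:

```python
from math import prod
from typing import Tuple


def flat_input_size(image_shape: Tuple[int, int, int]) -> int:
    """Features seen by a fully connected layer after flattening an image
    of shape (channels, height, width)."""
    return prod(image_shape)


MNIST_SHAPE = (1, 28, 28)    # grayscale, 28x28 pixels
CIFAR10_SHAPE = (3, 32, 32)  # RGB, 32x32 pixels

print(flat_input_size(MNIST_SHAPE), flat_input_size(CIFAR10_SHAPE))  # 784 3072
```

Both datasets have 10 classes, so the output size can stay the same. For a convolutional model, the number of input channels changes from 1 to 3 instead.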

Finally, we edit the test configuration in configs/test.yaml: under defaults, we likewise set datamodule and model to cifar_10.yaml, and we set ckpt_path to the checkpoint of the trained model we want to test.
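If you are unsure which file to pass as ckpt_path, a small helper can pick the most recently written checkpoint. This is a hypothetical sketch: it assumes checkpoints are saved as .ckpt files (the extension Lightning's ModelCheckpoint uses) somewhere under a logs/ directory, which you should replace with wherever your runs actually store them.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir: str = "logs/") -> Optional[Path]:
    """Hypothetical helper: return the newest .ckpt file under log_dir,
    searching subdirectories recursively, or None if there is none."""
    root = Path(log_dir)
    if not root.is_dir():
        return None
    checkpoints = sorted(root.rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return checkpoints[-1] if checkpoints else None


print(latest_checkpoint())  # None if no checkpoint has been saved yet
```

The returned path can then be used as the ckpt_path value in test.yaml.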

Congratulations! You can now use other datasets in the template.