PyTorch Lightning: Data module and test dataset#

The PyTorch Lightning data module (LightningDataModule) is a standard way to organize your data loading code. It separates data preparation (downloads, splits, dataloaders) from the model and training code, keeping both cleaner and easier to maintain.

A typical data module looks like this:

class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    def prepare_data(self):
        # Download or prepare your data here
        pass

    def setup(self, stage=None):
        # Load your data here
        self.train_dataset = MyDataset(train=True)
        self.val_dataset = MyDataset(train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

The only change when calling the Trainer is passing the data module instead of individual dataloaders:

trainer.fit(model, datamodule=dm)

Below is a complete MNIST example, starting with the imports.
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics # for metrics

Data module#

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # download both splits once (runs on a single process)
        MNIST(root="./MNIST", train=True, download=True)
        MNIST(root="./MNIST", train=False, download=True)
    
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            full = MNIST(root="./MNIST", train=True, transform=self.transform)
            self.train, self.val = random_split(full, [55000, 5000])

        if stage == "test" or stage is None:
            self.test = MNIST(root="./MNIST", train=False, transform=self.transform)
    
    def train_dataloader(self):
        # shuffle the training data each epoch; num_workers can be raised here
        # to silence the bottleneck warning shown in the training output below
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
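
The data module can also be exercised outside the Trainer, which is a quick way to sanity-check split sizes and batch shapes. A minimal sketch, calling the hooks manually the way the Trainer would:

dm = MNISTDataModule(batch_size=64)
dm.prepare_data()   # download once
dm.setup("fit")     # build the train/val splits

xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)          # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])
print(len(dm.train), len(dm.val))  # expected: 55000 5000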

NN class#

class LitModel(L.LightningModule):  # a drop-in replacement for nn.Module
    def __init__(self):
        super().__init__()  # initialize the LightningModule internals
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

        # one Accuracy metric per stage so the running statistics stay separate
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.train_acc(logits, y)
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return {"loss": loss}  # equivalent to simply returning the loss tensor

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.val_acc(logits, y)
        self.log("val_loss", loss, prog_bar=True) # prog_bar=False will not show the val loss in the training progress bar.
        self.log("val_acc", acc, prog_bar=True)
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.test_acc(logits, y)
        self.log("test_loss", loss)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)  # self.parameters() already includes self.model's weights
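
Before handing the model to the Trainer, a quick shape check on a dummy batch can catch wiring mistakes early. A small sketch, assuming an MNIST-shaped batch of random values:

model = LitModel()
dummy = torch.randn(4, 1, 28, 28)  # fake batch of 4 MNIST-sized images
logits = model(dummy)
print(logits.shape)                # expected: torch.Size([4, 10])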
# data module
dm = MNISTDataModule()

# checkpoint the best model based on validation loss
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")

# trainer
model = LitModel()
trainer = L.Trainer(max_epochs=3,
                    accelerator="auto",  # "auto" selects the GPU if one is available
                    callbacks=[checkpoint_cb])
trainer.fit(model, datamodule=dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | Sequential         | 101 K  | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
                                                                            
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Epoch 2: 100%|██████████| 860/860 [00:15<00:00, 55.43it/s, v_num=4, train_acc=0.958, val_loss=0.124, val_acc=0.965]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|██████████| 860/860 [00:15<00:00, 55.40it/s, v_num=4, train_acc=0.958, val_loss=0.124, val_acc=0.965]
trainer.test(model, datamodule=dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 80.78it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │    0.9667999744415283     │
│         test_loss         │    0.11054354161024094    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.11054354161024094, 'test_acc': 0.9667999744415283}]
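
The ModelCheckpoint callback keeps the best checkpoint (lowest val_loss) on disk, and its path is available as best_model_path. A hedged sketch of reloading it for inference on one test batch (best_path, best_model, xb, yb, preds are illustrative names):

best_path = checkpoint_cb.best_model_path
best_model = LitModel.load_from_checkpoint(best_path)
best_model.eval()

xb, yb = next(iter(dm.test_dataloader()))
with torch.no_grad():
    preds = best_model(xb.to(best_model.device)).argmax(dim=1)
print((preds.cpu() == yb).float().mean())  # rough accuracy on this single batch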