PyTorch Lightning: calculating different metrics

Torchmetrics is a library that provides a collection of metrics for PyTorch. These metrics do all the heavy lifting for you. Here we will focus on:

  • per-class accuracy

  • precision/recall/F1

  • top-5 accuracy

  • confusion matrix

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics # for metrics
class MNISTDataModule(L.LightningDataModule):
    """Serve MNIST train/validation/test dataloaders from ``./MNIST``.

    The training split is divided 55,000 / 5,000 into training and
    validation subsets; the ``train=False`` split is used for testing.

    Args:
        batch_size: Number of samples per batch for every dataloader.
    """

    def __init__(self, batch_size: int = 64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        """Download the training and the test split if they are missing."""
        MNIST(root="./MNIST", train=True, download=True)
        # The second call must request the test split: `setup("test")` opens
        # it without `download=True` and relies on the files being on disk.
        MNIST(root="./MNIST", train=False, download=True)

    def setup(self, stage: str | None = None):
        """Create the datasets needed for ``stage`` (all of them if None)."""
        if stage == "fit" or stage is None:
            full = MNIST(root="./MNIST", train=True, transform=self.transform)
            self.train, self.val = random_split(full, [55000, 5000])

        if stage == "test" or stage is None:
            self.test = MNIST(root="./MNIST", train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        # Without shuffle=True every epoch would iterate the batches in the
        # exact same order.
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test, batch_size=self.batch_size)
class LitModel(L.LightningModule):  # a replacement for nn.Module
    """Small MLP classifier for 28x28 inputs with torchmetrics evaluation.

    Every stage (train / validation / test) owns its own metric objects so
    that accumulated statistics are never mixed between stages. Validation
    and test metrics are logged as metric *objects*: Lightning then calls
    ``compute()`` once per epoch on the accumulated state and resets it,
    instead of averaging per-batch scores (which is not the dataset-level
    value for macro-averaged precision / recall / F1).

    Logged keys: ``train_loss``, ``train_acc``, and for ``val`` / ``test``:
    ``*_loss``, ``*_acc``, ``*_top5_acc``, ``*_precision``, ``*_recall``,
    ``*_f1``.
    """

    def __init__(self):
        # Initialise the LightningModule machinery before adding submodules.
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)

        # Validation metrics. The un-prefixed attribute names are kept so
        # existing code that reads e.g. `model.precision` keeps working.
        self.precision = torchmetrics.classification.MulticlassPrecision(num_classes=10, average='macro')
        self.recall = torchmetrics.classification.MulticlassRecall(num_classes=10, average='macro')
        self.f1 = torchmetrics.classification.MulticlassF1Score(num_classes=10, average='macro')
        self.top5_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10, top_k=5)

        # Test metrics: separate instances so test statistics never share
        # state with validation.
        self.test_precision = torchmetrics.classification.MulticlassPrecision(num_classes=10, average='macro')
        self.test_recall = torchmetrics.classification.MulticlassRecall(num_classes=10, average='macro')
        self.test_f1 = torchmetrics.classification.MulticlassF1Score(num_classes=10, average='macro')
        self.test_top5_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10, top_k=5)

        # Instantiated but not yet updated or logged by any step.
        self.confmat = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=10)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and log step-level loss/accuracy."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # forward() updates the metric state and caches the batch value that
        # Lightning reports for step-level logging.
        self.train_acc(logits, y)

        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False, prog_bar=True)

        # Returning the bare loss tensor would be equivalent.
        return {"loss": loss}

    def _eval_step(self, batch, stage, metrics):
        """Log the loss and the epoch-level ``metrics`` for one eval batch.

        Args:
            batch: ``(inputs, targets)`` pair from the dataloader.
            stage: Key prefix, ``"val"`` or ``"test"``.
            metrics: Mapping of key suffix to the metric object of this stage.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Only the validation loss is shown in the progress bar.
        self.log(f"{stage}_loss", loss, prog_bar=(stage == "val"))
        for name, metric in metrics.items():
            # Accumulate only; Lightning computes and resets at epoch end.
            metric.update(logits, y)
            self.log(
                f"{stage}_{name}",
                metric,
                on_step=False,
                on_epoch=True,
                prog_bar=(name == "acc"),
            )

    def validation_step(self, batch, batch_idx):
        """Accumulate and log the validation loss and metrics."""
        self._eval_step(
            batch,
            "val",
            {
                "acc": self.val_acc,
                "top5_acc": self.top5_acc,
                "precision": self.precision,
                "recall": self.recall,
                "f1": self.f1,
            },
        )

    def test_step(self, batch, batch_idx):
        """Accumulate and log the test loss and metrics."""
        self._eval_step(
            batch,
            "test",
            {
                "acc": self.test_acc,
                "top5_acc": self.test_top5_acc,
                "precision": self.test_precision,
                "recall": self.test_recall,
                "f1": self.test_f1,
            },
        )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        # self.parameters() covers every registered submodule, so there is no
        # need to reach into self.model explicitly.
        return torch.optim.Adam(self.parameters(), lr=1e-3)
# Data: MNIST wrapped in the LightningDataModule defined above.
dm = MNISTDataModule()

# Checkpointing: track the validation loss, lower is better.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")

# Model and trainer. accelerator="auto" selects a GPU when one is available.
model = LitModel()
trainer = L.Trainer(
    max_epochs=3,
    accelerator="auto",
    callbacks=[checkpoint_cb],
)
trainer.fit(model, datamodule=dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                      | Params | Mode 
----------------------------------------------------------------
0 | model     | Sequential                | 101 K  | train
1 | train_acc | MulticlassAccuracy        | 0      | train
2 | val_acc   | MulticlassAccuracy        | 0      | train
3 | test_acc  | MulticlassAccuracy        | 0      | train
4 | precision | MulticlassPrecision       | 0      | train
5 | recall    | MulticlassRecall          | 0      | train
6 | f1        | MulticlassF1Score         | 0      | train
7 | top5_acc  | MulticlassAccuracy        | 0      | train
8 | confmat   | MulticlassConfusionMatrix | 0      | train
----------------------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode
                                                                           
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Epoch 2: 100%|██████████| 860/860 [00:16<00:00, 50.92it/s, v_num=6, train_acc=1.000, val_loss=0.130, val_acc=0.962]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|██████████| 860/860 [00:16<00:00, 50.89it/s, v_num=6, train_acc=1.000, val_loss=0.130, val_acc=0.962]
# Evaluate the in-memory model on the datamodule's test dataloader.
trainer.test(model=model, datamodule=dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 157/157 [00:03<00:00, 48.08it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc              0.9656000137329102     │
│          test_f1               0.961219847202301     │
│         test_loss             0.11702169477939606    │
│      test_precision           0.9657908082008362     │
│        test_recall            0.9643090963363647     │
│       test_top5_acc           0.9992062449455261     │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.11702169477939606,
  'test_acc': 0.9656000137329102,
  'test_top5_acc': 0.9992062449455261,
  'test_precision': 0.9657908082008362,
  'test_recall': 0.9643090963363647,
  'test_f1': 0.961219847202301}]