PyTorch Lightning: calculating different metrics
TorchMetrics is a library that provides a collection of metrics for PyTorch. These metric objects do all the heavy lifting for you: they accumulate state across batches and can be reset between epochs. Here we will focus on:
- per-class accuracy
- precision/recall/F1
- top-5 accuracy
- confusion matrix
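Before wiring these into Lightning, it helps to see the basic torchmetrics pattern on its own: a metric is a stateful object that you update() with each batch and compute() at the end. A minimal standalone sketch (the random tensors are just stand-ins for real model outputs):

import torch
import torchmetrics

acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10)
logits = torch.randn(32, 10)             # stand-in batch of logits
targets = torch.randint(0, 10, (32,))    # stand-in labels
acc.update(logits, targets)              # accumulate state for this batch
print(acc.compute())                     # accuracy over everything seen so far
acc.reset()                              # clear state before the next epoch

Inside a LightningModule, calling the metric object directly (as below) both updates the state and returns the batch-level value.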
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics # for metrics
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # download both splits once; setup() only reads from disk
        MNIST(root="./MNIST", train=True, download=True)
        MNIST(root="./MNIST", train=False, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            full = MNIST(root="./MNIST", train=True, transform=self.transform)
            self.train, self.val = random_split(full, [55000, 5000])
        if stage == "test" or stage is None:
            self.test = MNIST(root="./MNIST", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
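A quick way to sanity-check the datamodule outside a Trainer (Lightning normally calls these hooks for you, in this order):

dm = MNISTDataModule()
dm.prepare_data()
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])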
class LitModel(L.LightningModule):  # a drop-in replacement for nn.Module
    def __init__(self):
        super().__init__()  # call the superclass __init__ to set up LightningModule internals
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        # one Accuracy object per stage so their running state never mixes
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.precision = torchmetrics.classification.MulticlassPrecision(num_classes=10, average='macro')
        self.recall = torchmetrics.classification.MulticlassRecall(num_classes=10, average='macro')
        self.f1 = torchmetrics.classification.MulticlassF1Score(num_classes=10, average='macro')
        self.top5_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10, top_k=5)
        self.confmat = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=10)  # created here but not used yet; see the sketch after the class
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.train_acc(logits, y)  # torchmetrics argmaxes the logits internally
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return {"loss": loss}  # returning the bare loss tensor would work the same way
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.val_acc(logits, y)
        top5 = self.top5_acc(logits, y)
        prec = self.precision(logits, y)
        rec = self.recall(logits, y)
        f1 = self.f1(logits, y)
        self.log("val_loss", loss, prog_bar=True)  # prog_bar=False would hide val_loss from the progress bar
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_top5_acc", top5)
        self.log("val_precision", prec)
        self.log("val_recall", rec)
        self.log("val_f1", f1)
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.test_acc(logits, y)
        top5 = self.top5_acc(logits, y)
        prec = self.precision(logits, y)
        rec = self.recall(logits, y)
        f1 = self.f1(logits, y)
        self.log("test_loss", loss)
        self.log("test_acc", acc, prog_bar=True)
        self.log("test_top5_acc", top5)
        self.log("test_precision", prec)
        self.log("test_recall", rec)
        self.log("test_f1", f1)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)  # self.parameters() already includes self.model's parameters
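The confmat metric above is instantiated but never updated, so the confusion matrix from our list is still missing. A sketch of one way to wire it in: call update() inside test_step and read the accumulated matrix in the on_test_epoch_end hook (both are standard Lightning/torchmetrics APIs; the print is just for illustration and these additions are not part of the run below):

    # hypothetical additions inside LitModel
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.confmat.update(logits, y)   # accumulate counts into the 10x10 matrix
        # ... plus the loss/metric logging shown above ...

    def on_test_epoch_end(self):
        cm = self.confmat.compute()      # cm[i, j] = samples of true class i predicted as j
        print(cm)
        self.confmat.reset()             # clear state so a later test run starts fresh

A confusion matrix cannot go through self.log, which expects scalars, which is why it is printed (or handed to a plotting logger) in the epoch-end hook instead.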
# Datamodule
dm = MNISTDataModule()
# checkpoint based on val loss
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")
# model and trainer
model = LitModel()
trainer = L.Trainer(max_epochs=3,
                    accelerator="auto",  # "auto" selects a GPU if one is available
                    callbacks=[checkpoint_cb])
trainer.fit(model, datamodule=dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
----------------------------------------------------------------
0 | model | Sequential | 101 K | train
1 | train_acc | MulticlassAccuracy | 0 | train
2 | val_acc | MulticlassAccuracy | 0 | train
3 | test_acc | MulticlassAccuracy | 0 | train
4 | precision | MulticlassPrecision | 0 | train
5 | recall | MulticlassRecall | 0 | train
6 | f1 | MulticlassF1Score | 0 | train
7 | top5_acc | MulticlassAccuracy | 0 | train
8 | confmat | MulticlassConfusionMatrix | 0 | train
----------------------------------------------------------------
101 K Trainable params
0 Non-trainable params
101 K Total params
0.407 Total estimated model params size (MB)
13 Modules in train mode
0 Modules in eval mode
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Epoch 2: 100%|██████████| 860/860 [00:16<00:00, 50.92it/s, v_num=6, train_acc=1.000, val_loss=0.130, val_acc=0.962]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|██████████| 860/860 [00:16<00:00, 50.89it/s, v_num=6, train_acc=1.000, val_loss=0.130, val_acc=0.962]
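The next cell tests the in-memory weights from the end of training. Since a ModelCheckpoint callback is attached, one could instead evaluate the epoch with the lowest val_loss; ckpt_path="best" is Lightning's shortcut for the attached checkpoint callback:

# alternative to the cell below: test the best checkpoint instead of the current weights
trainer.test(datamodule=dm, ckpt_path="best")
print(checkpoint_cb.best_model_path)  # where that checkpoint lives on disk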
trainer.test(model, datamodule=dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 157/157 [00:03<00:00, 48.08it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │    0.9656000137329102     │
│          test_f1          │     0.961219847202301     │
│         test_loss         │    0.11702169477939606    │
│      test_precision       │    0.9657908082008362     │
│        test_recall        │    0.9643090963363647     │
│       test_top5_acc       │    0.9992062449455261     │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.11702169477939606,
'test_acc': 0.9656000137329102,
'test_top5_acc': 0.9992062449455261,
'test_precision': 0.9657908082008362,
'test_recall': 0.9643090963363647,
'test_f1': 0.961219847202301}]
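The list at the top also mentions per-class accuracy, but every metric logged above reduces to a single scalar. torchmetrics exposes the unreduced version via average=None, which returns one value per class. A standalone sketch with stand-in tensors (reusing the torch/torchmetrics imports from the top):

per_class_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average=None)
logits = torch.randn(256, 10)             # stand-in logits
targets = torch.randint(0, 10, (256,))    # stand-in labels
per_class_acc.update(logits, targets)
print(per_class_acc.compute())            # tensor of 10 accuracies, one per digit

As with the confusion matrix, a per-class tensor cannot be passed to self.log directly; inside a LightningModule you would compute it in an epoch-end hook and either print it or log each entry separately.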