PyTorch Lightning: Data module and test dataset
The PyTorch Lightning data module is a standard way to organize your data loading code. It helps you separate the data preparation from the model training code, making your code cleaner and more maintainable.
A typical data module looks like this:
import lightning.pytorch as pl
from torch.utils.data import DataLoader

class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # Download or prepare your data here (called once, on a single process)
        pass

    def setup(self, stage=None):
        # Load your data here (called on every process/GPU)
        self.train_dataset = MyDataset(train=True)
        self.val_dataset = MyDataset(train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
The only change on the training side is that you pass the data module to the Trainer instead of individual dataloaders:
trainer.fit(model, datamodule=dm)
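To make this concrete, here is a minimal, hypothetical usage sketch (MyDataModule is the placeholder class above; MyModel stands in for any LightningModule you have defined):

# Hypothetical sketch: MyModel is a stand-in for any LightningModule
dm = MyDataModule(data_dir="./data", batch_size=32)
model = MyModel()
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, datamodule=dm)  # Lightning calls prepare_data(), setup(), and the dataloaders for you

The rest of this section builds a complete MNIST example, starting with the imports: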
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics # for metrics
Data module
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # Download both splits once; setup() only reads from disk
        MNIST(root="./MNIST", train=True, download=True)
        MNIST(root="./MNIST", train=False, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            full = MNIST(root="./MNIST", train=True, transform=self.transform)
            self.train, self.val = random_split(full, [55000, 5000])
        if stage == "test" or stage is None:
            self.test = MNIST(root="./MNIST", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
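As a quick sanity check (optional, and not part of the run below), you can exercise the data module by hand before handing it to the Trainer:

# Optional sketch: call the hooks manually and inspect one batch
dm = MNISTDataModule(batch_size=64)
dm.prepare_data()
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])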
NN class
class LitModel(L.LightningModule):  # a drop-in replacement for nn.Module
    def __init__(self):
        super().__init__()  # call __init__ of the superclass to initialize the LightningModule internals
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        # one metric object per stage, so their internal state never mixes
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.train_acc(logits, y)
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return {"loss": loss}  # returning the bare loss tensor is equivalent

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.val_acc(logits, y)
        self.log("val_loss", loss, prog_bar=True)  # prog_bar=False would hide it from the progress bar
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.test_acc(logits, y)
        self.log("test_loss", loss)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        # self.parameters() already includes self.model's parameters
        return torch.optim.Adam(self.parameters(), lr=1e-3)
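The model can be smoke-tested in isolation (an optional sketch, not part of the original run) to confirm that a batch of MNIST-shaped images maps to 10 logits per sample:

# Optional sketch: push a random batch through the untrained model
model = LitModel()
dummy = torch.randn(4, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([4, 10])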
# Data module
dm = MNISTDataModule()

# checkpoint the weights with the lowest validation loss
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")

# model and trainer
model = LitModel()
trainer = L.Trainer(
    max_epochs=3,
    accelerator="auto",  # "auto" selects the GPU if one is available
    callbacks=[checkpoint_cb],
)
trainer.fit(model, datamodule=dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------------------------
0 | model | Sequential | 101 K | train
1 | train_acc | MulticlassAccuracy | 0 | train
2 | val_acc | MulticlassAccuracy | 0 | train
3 | test_acc | MulticlassAccuracy | 0 | train
---------------------------------------------------------
101 K Trainable params
0 Non-trainable params
101 K Total params
0.407 Total estimated model params size (MB)
8 Modules in train mode
0 Modules in eval mode
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Epoch 2: 100%|██████████| 860/860 [00:15<00:00, 55.43it/s, v_num=4, train_acc=0.958, val_loss=0.124, val_acc=0.965]
`Trainer.fit` stopped: `max_epochs=3` reached.
trainer.test(model, datamodule=dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/hell/Desktop/lightning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 80.78it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │    0.9667999744415283     │
│         test_loss         │    0.11054354161024094    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.11054354161024094, 'test_acc': 0.9667999744415283}]
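Note that trainer.test(model, ...) above evaluates the model's final weights. Since ModelCheckpoint tracked the lowest val_loss, you can instead restore and test the best checkpoint (a sketch using Lightning's standard checkpoint API):

# Optional: evaluate the checkpoint with the lowest val_loss instead of the final weights
print(checkpoint_cb.best_model_path)  # path where the best weights were saved
best_model = LitModel.load_from_checkpoint(checkpoint_cb.best_model_path)
trainer.test(best_model, datamodule=dm)
# equivalent shortcut: trainer.test(model, datamodule=dm, ckpt_path="best")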