PyTorch Lightning

PyTorch Lightning (PL) is a high-level framework built on top of PyTorch. It makes training deep learning models more structured and less repetitive by removing most of the boilerplate code, such as hand-written training loops, device handling, and checkpointing. Several modules are central to PyTorch Lightning, and the project has thorough documentation.

LightningModule

A subclass of the PyTorch nn.Module class that provides a standard interface for training, validation, testing, and optimisation. In addition to forward(), it defines hooks such as training_step(), validation_step(), and configure_optimizers(), which the Trainer calls at the right points in the loop. This makes the code standardised and easy to read.

LightningDataModule

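This module gathers all the data-loading logic into a single class, with hooks for downloading (prepare_data()), preparing and splitting (setup()), and loading data (train_dataloader(), val_dataloader(), test_dataloader()). This keeps the data logic separate from the model logic and makes it reusable across experiments. It is optional: a LightningModule can also be trained with plain DataLoaders passed directly to the Trainer.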

Trainer

Runs the training, validation, and testing loops and handles the engineering around them, such as logging, checkpointing, and GPU usage. To stop a training run early, press Ctrl+C: PL catches the interrupt and attempts a graceful shutdown instead of crashing, so checkpoints and logs written up to that point are kept.

Torchmetrics

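TorchMetrics is a separate library, developed alongside PyTorch Lightning and also usable with plain PyTorch, that provides a large collection of metrics (accuracy, F1, AUROC, and many more). Metrics operate directly on tensors, so chores such as detaching them from the graph or converting them to NumPy are handled by the library, and metric state is accumulated across batches (and devices) before the final value is computed.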

Reproducibility

PL provides lightning.seed_everything() to seed the random number generators of the commonly used libraries (NumPy, PyTorch including CUDA, Python's random module) in a single call.

Saving results

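By default, PL saves a checkpoint at the end of each training epoch and keeps only the most recent one; to keep the best checkpoint according to the validation loss, add a ModelCheckpoint callback with monitor="val_loss". Metrics are written by the default logger: TensorBoard event files if TensorBoard is installed, otherwise a metrics.csv file (CSVLogger). The default location is lightning_logs/version_0/ in the current working directory, with checkpoints in its checkpoints/ subfolder. Each subsequent run creates a new version folder (version_1, version_2, and so on).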

Logging

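Inside a LightningModule, self.log() records a value, and PL forwards it to the configured logger or loggers. Built-in loggers cover TensorBoard, CSV files, Comet, MLflow, Weights & Biases, and other platforms.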

Device selection

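There is no need to create a device variable manually or to call .to(device) on the model and each batch. PL moves the model, and the batches produced by the dataloaders, to the appropriate device based on the Trainer settings. L.Trainer(accelerator="auto") will do the job: it picks an available accelerator such as a GPU, TPU, or Apple MPS device and falls back to the CPU otherwise.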