
Introduction To PyTorch Lightning

An adaptation of the Introduction to PyTorch Lightning tutorial for Habana Gaudi AI processors.

In this tutorial, we’ll go over the basics of Lightning by preparing models to train on the MNIST Handwritten Digits dataset.

Setup

This tutorial requires some packages besides pytorch-lightning.

! pip install --quiet "torchvision" "torchmetrics" 
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from habana_frameworks.torch.utils.library_loader import load_habana_module
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256

Gaudi Initialization

Load the Gaudi HPU library and set the required environment variables.

def HPUInitialization():
    os.environ['MAX_WAIT_ATTEMPTS'] = "50"
    os.environ['PT_HPU_ENABLE_SYNC_OUTPUT_HOST'] = "false"
    load_habana_module()

Simplest example

Here’s the simplest, most minimal example, with just a training loop (no validation, no testing).

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
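
To make the loss in training_step concrete, here is a minimal pure-Python sketch of what a cross-entropy loss computes for a single example: a log-softmax over the logits, followed by the negative log-probability of the target class. The function names and the example logits are illustrative assumptions, not PyTorch code; F.cross_entropy also averages over the batch by default.

```python
import math


def log_softmax(logits):
    """Return log-probabilities for one example (numerically stable)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the target class."""
    return -log_softmax(logits)[target]


# Illustrative logits for a 3-class problem (made-up numbers).
logits = [2.0, 0.5, -1.0]
loss_correct = cross_entropy(logits, 0)  # target is the highest-scoring class
loss_wrong = cross_entropy(logits, 2)    # target is the lowest-scoring class
print(loss_correct < loss_wrong)  # True
```

The same two steps appear explicitly later in LitMNIST, where forward applies F.log_softmax and the step methods call F.nll_loss.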

By using the Trainer you automatically get:

  1. TensorBoard logging
  2. Model checkpointing
  3. Training and validation loop
  4. Early stopping

HPUInitialization()

# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = Trainer(
    hpus=1,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)

Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/callback_connector.py:90: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=20)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: True, using: 1 HPUs

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 112 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

A more complete MNIST Lightning Module Example

That wasn’t so hard, was it?

Now that we’ve got our feet wet, let’s dive in a bit deeper and write a more complete LightningModule for MNIST…

This time, we’ll bake all the dataset-specific pieces directly into the LightningModule. This way, we avoid writing extra code at the beginning of our script every time we want to run it.

Note what the following built-in functions are doing:


prepare_data() 💾

  • This is where we can download the dataset. We point to our desired dataset and ask torchvision’s MNIST dataset class to download if the dataset isn’t found there.
  • Note that we do not make any state assignments in this function (i.e. self.something = ...), because in distributed settings it is called on a single process only.

setup(stage) ⚙️

  • Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
  • Setup expects a ‘stage’ arg which is used to separate logic for ‘fit’ and ‘test’.
  • If you don’t mind loading all your datasets at once, you can set up a condition to allow for both ‘fit’ related setup and ‘test’ related setup to run whenever None is passed to stage (or ignore it altogether and exclude any conditionals).
  • Note that this runs on every process (across all GPUs, or HPUs in this tutorial), so it is safe to make state assignments here.

x_dataloader() ♻️

  • train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances that wrap the respective datasets we prepared in setup().

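The stage logic described above can be sketched without Lightning. The class below is a hypothetical stand-in (ToyDataHooks, with plain lists in place of real datasets) that only mimics how setup(stage) decides which splits to build; it is not the Lightning API.

```python
class ToyDataHooks:
    """Hypothetical stand-in that mimics the setup(stage) pattern."""

    def prepare_data(self):
        # Download-style work only; deliberately no self.<attr> assignments.
        print("downloading (pretend)")

    def setup(self, stage=None):
        # 'fit' (or None) builds the train/val splits.
        if stage == "fit" or stage is None:
            full = list(range(10))  # placeholder for the training set
            self.train, self.val = full[:8], full[8:]
        # 'test' (or None) builds the test split.
        if stage == "test" or stage is None:
            self.test = list(range(100, 104))  # placeholder for the test set


fit_only = ToyDataHooks()
fit_only.setup("fit")
print(hasattr(fit_only, "train"), hasattr(fit_only, "test"))  # True False

everything = ToyDataHooks()
everything.setup()  # stage=None runs both branches
print(hasattr(everything, "train"), hasattr(everything, "test"))  # True True
```

Passing None builds every split at once, which is the "load everything" option mentioned in the list.
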
class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
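
The constants passed to transforms.Normalize are the commonly quoted mean (0.1307) and standard deviation (0.3081) of the MNIST training pixels after ToTensor has scaled them to the range [0, 1]. A minimal sketch of the per-pixel arithmetic, using a hypothetical helper name:

```python
MEAN, STD = 0.1307, 0.3081  # values used in the transform above


def normalize(pixel):
    """Apply (pixel - mean) / std to one pixel value in [0, 1]."""
    return (pixel - MEAN) / STD


black = normalize(0.0)  # background pixel
white = normalize(1.0)  # fully bright pixel
print(round(black, 4), round(white, 4))  # -0.4242 2.8215
```

After normalization the inputs are roughly zero-centered with unit variance, which tends to make optimization better behaved.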

Training the Model on Gaudi

model = LitMNIST()
trainer = Trainer(
    hpus=1,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: True, using: 1 HPUs

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 55.1 K
1 | accuracy | Accuracy   | 0
----------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 112 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
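
The parameter counts in the model summaries can be checked by hand: a fully connected layer with bias has in_features * out_features + out_features parameters. The sketch below uses a hypothetical helper, linear_params, and assumes 4-byte (float32) parameters for the size estimate.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


# Simplest example: a single Linear(28 * 28, 10).
simple = linear_params(28 * 28, 10)

# LitMNIST with hidden_size=64: 784 -> 64 -> 64 -> 10.
hidden = 64
lit = (
    linear_params(28 * 28, hidden)
    + linear_params(hidden, hidden)
    + linear_params(hidden, 10)
)

print(simple, lit)  # 7850 55050
# Size in MB, assuming 4 bytes per parameter (float32).
print(round(simple * 4 / 1e6, 3), round(lit * 4 / 1e6, 3))  # 0.031 0.22
```

These values agree with the 7.9 K / 0.031 MB and 55.1 K / 0.220 MB figures reported for the two models.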

Testing

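To test a model, call trainer.test(model). Calling trainer.test() with no arguments, as below, tests the best model from the previous fit call and emits a warning saying so.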

trainer.test()

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py:1413: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `test(ckpt_path='best')` to use and best model checkpoint and avoid this warning or `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model.
  rank_zero_warn(
Restoring states from the checkpoint path at /model_garden/internal/PyTorch/cpu_fallback/pytorch-lightning/text-classification/lightning_logs/version_16/checkpoints/epoch=2-step=644.ckpt
Loaded model weights from checkpoint at /model_garden/internal/PyTorch/cpu_fallback/pytorch-lightning/text-classification/lightning_logs/version_16/checkpoints/epoch=2-step=644.ckpt
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 112 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.9248999953269958, 'val_loss': 0.25047656893730164}
--------------------------------------------------------------------------------
[{'val_loss': 0.25047656893730164, 'val_acc': 0.9248999953269958}]
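
The step number in the checkpoint filename can be sanity-checked from the split and batch sizes. This sketch assumes the DataLoader keeps the final partial batch (the default drop_last=False) and that one optimizer step is taken per batch.

```python
import math

TRAIN_SAMPLES = 55000  # training split produced by random_split
BATCH_SIZE = 256
MAX_EPOCHS = 3

batches_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
last_batch_size = TRAIN_SAMPLES - (batches_per_epoch - 1) * BATCH_SIZE
total_steps = batches_per_epoch * MAX_EPOCHS

print(batches_per_epoch, last_batch_size, total_steps)  # 215 216 645
```

A total of 645 steps is consistent with epoch=2-step=644 in the log if both counters are zero-based, which is an assumption about the Lightning version used here.
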

Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.

License

Licensed under a CC BY-SA 4.0 license.

A derivative of Introduction To PyTorch Lightning by PL Team
