
Finetune Transformers Models with PyTorch Lightning

An adaptation of the Finetune Transformers Models with PyTorch Lightning tutorial using Habana Gaudi AI processors.

This notebook uses HuggingFace's datasets library to download the data, which is wrapped in a LightningDataModule. We then write a class that performs text classification on any dataset from the GLUE Benchmark. (We only show CoLA and MRPC due to compute/disk constraints.)

Setup

This notebook requires some packages besides pytorch-lightning.

Download the Habana Model-References repository from GitHub.

! git clone https://github.com/HabanaAI/Model-References.git
Cloning into 'Model-References'...
Updating files: 100% (2666/2666), done.
! pip install --quiet "datasets" "scipy" "scikit-learn"
! pip install Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers/.
Processing ./Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers
Building wheels for collected packages: transformers
  Building wheel for transformers (setup.py) ... done
Successfully built transformers
Installing collected packages: transformers
Successfully installed transformers-4.15.0
from datetime import datetime
from typing import Optional
import os
import datasets
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from habana_frameworks.torch.utils.library_loader import load_habana_module

HPU Initialization

This includes loading the HPU library and setting the required environment variables.

def HPUInitialization():
    # Environment settings consumed by the Habana PyTorch bridge
    os.environ['MAX_WAIT_ATTEMPTS'] = "50"
    os.environ['PT_HPU_ENABLE_SYNC_OUTPUT_HOST'] = "false"
    # Load the Habana module so the "hpu" device is registered with PyTorch
    load_habana_module()

HPUInitialization()
Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib
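As an optional sanity check (not part of the original notebook), you can confirm the HPU is usable by allocating a small tensor on it; this sketch assumes load_habana_module() above completed successfully.

x = torch.ones(2, 3, device="hpu")   # fails if the Habana bridge is not loaded
print(x.device)                      # hpu:0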

Training BERT with Lightning

Lightning DataModule for GLUE

class GLUEDataModule(LightningDataModule):

    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):

        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
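For illustration only, here is what the tokenization step inside convert_to_features produces for a single sentence pair. The checkpoint name is assumed (distilbert-base-uncased, the same one used below); any GLUE pair task is encoded the same way.

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
feats = tok.batch_encode_plus(
    [("He said hello.", "He greeted them.")],
    max_length=128, padding="max_length", truncation=True,
)
print(list(feats.keys()))  # ['input_ids', 'attention_mask'] -- DistilBERT tokenizers emit no token_type_ids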

Prepare the data using the DataModule

dm = GLUEDataModule("distilbert-base-uncased")
dm.prepare_data()
dm.setup("fit")
next(iter(dm.train_dataloader()))
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-419422735d445439.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-66e685e20329ff99.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-45b22f8d8d64c8ae.arrow
{'input_ids': tensor([[ 101, 2572, 3217, ..., 0, 0, 0], [ 101, 9805, 3540, ..., 0, 0, 0], [ 101, 2027, 2018, ..., 0, 0, 0], ..., [ 101, 1996, 2922, ..., 0, 0, 0], [ 101, 6202, 1999, ..., 0, 0, 0], [ 101, 16565, 2566, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], ..., [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0])}
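If you prefer checking shapes rather than raw tensors, a small optional inspection like the following works with the dm defined above; the expected sizes follow from train_batch_size=32 and max_seq_length=128.

batch = next(iter(dm.train_dataloader()))
print(batch["input_ids"].shape)       # torch.Size([32, 128]): (train_batch_size, max_seq_length)
print(batch["attention_mask"].shape)  # torch.Size([32, 128])
print(batch["labels"].shape)          # torch.Size([32])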

Transformer LightningModule

class GLUETransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss

        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        return loss

    def setup(self, stage=None) -> None:
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()

        # Calculate total steps
        tb_size = self.hparams.train_batch_size * max(1, self.trainer.hpus)
        ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
        self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
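To make the setup() arithmetic above concrete, here is a worked example assuming the MRPC train split (3,668 examples), train_batch_size=32, a single HPU, accumulate_grad_batches=1, and max_epochs=1.

tb_size = 32 * 1                        # train_batch_size * number of HPUs
ab_size = 1 * float(1)                  # accumulate_grad_batches * max_epochs
total_steps = (3668 // tb_size) // ab_size
print(total_steps)                      # 114.0 optimizer steps for the linear warmup/decay schedule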

Training

CoLA

See an interactive view of the CoLA dataset in NLP Viewer

seed_everything(42)

dm = GLUEDataModule(model_name_or_path="albert-base-v2", task_name="cola")
dm.setup("fit")
model = GLUETransformer(
    model_name_or_path="albert-base-v2",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(max_epochs=1, hpus=1)
trainer.fit(model, datamodule=dm)
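Once fit() finishes, you can optionally run the validation loop on the CoLA dev set with the same Trainer; this mirrors the validate() call used in the MNLI section below, and the metric key (matthews_correlation) comes from the GLUE metric loaded in the LightningModule.

trainer.validate(model, datamodule=dm)  # logs val_loss and matthews_correlation for CoLA
print(trainer.logged_metrics)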

MNLI

The MNLI dataset is huge, so we aren’t going to bother trying to train on it here.

We will skip over training and go straight to validation.

See an interactive view of the MNLI dataset in NLP Viewer

dm = GLUEDataModule(
    model_name_or_path="distilbert-base-cased",
    task_name="mnli",
)
dm.setup("fit")
model = GLUETransformer(
    model_name_or_path="distilbert-base-cased",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(hpus=1, progress_bar_refresh_rate=20)
trainer.validate(model, dm.val_dataloader())
trainer.logged_metrics
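The metric names follow from validation_epoch_end above (for MNLI, keys such as accuracy_matched and accuracy_mismatched alongside val_loss_matched and val_loss_mismatched); a small loop like this prints whatever was logged.

for name, value in trainer.logged_metrics.items():
    print(f"{name}: {float(value):.4f}")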

Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.

License

Licensed under a CC BY SA 4.0 license.

A derivative of Finetune Transformers Models with PyTorch Lightning by the PL Team
