An adaptation of the "Finetune Transformers Models with PyTorch Lightning" tutorial using Habana Gaudi AI processors.
This notebook uses HuggingFace's datasets library to load data, which is then wrapped in a LightningDataModule. We then write a class that can perform text classification on any dataset from the GLUE Benchmark. (We only show CoLA and MRPC due to compute and disk constraints.)
Setup
This notebook requires some packages besides pytorch-lightning.
Download the Habana Model-References repository from GitHub
! git clone https://github.com/HabanaAI/Model-References.git
Cloning into 'Model-References'...
error: unable to create file PyTorch/nlp/finetuning/huggingface/bert/transformers/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py: Filename too long
error: unable to create file PyTorch/nlp/finetuning/huggingface/bert/transformers/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py: Filename too long
error: unable to create file PyTorch/nlp/finetuning/huggingface/bert/transformers/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py: Filename too long
error: unable to create file PyTorch/nlp/finetuning/huggingface/bert/transformers/src/transformers/models/visual_bert/convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py: Filename too long
Updating files: 100% (2666/2666), done.
fatal: unable to checkout working tree
warning: Clone succeeded, but checkout failed.
You can inspect what was checked out with 'git status'
and retry with 'git restore --source=HEAD :/'
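If the clone reports "Filename too long" errors as above (common on Windows hosts, where Git's long-path support is disabled by default), one possible workaround, not part of the original notebook, is to enable long paths and then retry the checkout that Git suggests:
! git config --global core.longpaths true
! cd Model-References && git restore --source=HEAD :/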
! pip install --quiet "datasets" "scipy" "sklearn"
! pip install Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers/.
Processing ./Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers
Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (3.6.0)
Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (0.4.0)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (1.22.2)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (21.3)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (5.4.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (2020.10.28)
Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (2.27.1)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (0.0.47)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (0.10.3)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (4.62.3)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.15.0) (3.10.0.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->transformers==4.15.0) (3.0.7)
Requirement already satisfied: idna<4,>=2.5; python_version >= "3" in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (3.3)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (1.26.8)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (2021.10.8)
Requirement already satisfied: charset-normalizer~=2.0.0; python_version >= "3" in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (2.0.12)
Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==4.15.0) (8.0.4)
Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==4.15.0) (1.1.0)
Requirement already satisfied: six in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==4.15.0) (1.16.0)
Building wheels for collected packages: transformers
Building wheel for transformers (setup.py) ... done
Created wheel for transformers: filename=transformers-4.15.0-cp38-none-any.whl size=3337816 sha256=5d1495f939699635b956d11e3cf1b916afd6589771566b6e177f379a1b854f3b
Stored in directory: /tmp/pip-ephem-wheel-cache-cvgd2e2u/wheels/da/a9/2f/465b33c0a36e032a57095c7e9f6e8500c8d68f2335910d9931
Successfully built transformers
Installing collected packages: transformers
Found existing installation: transformers 4.15.0
Uninstalling transformers-4.15.0:
Successfully uninstalled transformers-4.15.0
Successfully installed transformers-4.15.0
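As a quick check (not part of the original notebook), you can confirm that the expected packages are importable before moving on:
import transformers
import pytorch_lightning as pl
print(transformers.__version__)  # expected: 4.15.0, installed from Model-References
print(pl.__version__)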
from datetime import datetime
from typing import Optional
import os
import datasets
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader
from transformers import (
AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
)
from habana_frameworks.torch.utils.library_loader import load_habana_module
HPU Initialization
Load the Habana HPU library and set the required environment variables.
def HPUInitialization():
os.environ['MAX_WAIT_ATTEMPTS'] = "50"
os.environ['PT_HPU_ENABLE_SYNC_OUTPUT_HOST'] = "false"
load_habana_module()
HPUInitialization()
Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib
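As an optional sanity check, assuming you are on a Gaudi host with the SynapseAI PyTorch bridge installed (this snippet is not part of the original notebook), you can try allocating a small tensor on the HPU device:
import torch
x = torch.ones(4, device=torch.device("hpu"))  # should succeed once the Habana module is loaded
print(x.device)  # expected: hpu:0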
Training BERT with Lightning
Lightning DataModule for GLUE
class GLUEDataModule(LightningDataModule):
task_text_field_map = {
"cola": ["sentence"],
"sst2": ["sentence"],
"mrpc": ["sentence1", "sentence2"],
"qqp": ["question1", "question2"],
"stsb": ["sentence1", "sentence2"],
"mnli": ["premise", "hypothesis"],
"qnli": ["question", "sentence"],
"rte": ["sentence1", "sentence2"],
"wnli": ["sentence1", "sentence2"],
"ax": ["premise", "hypothesis"],
}
glue_task_num_labels = {
"cola": 2,
"sst2": 2,
"mrpc": 2,
"qqp": 2,
"stsb": 1,
"mnli": 3,
"qnli": 2,
"rte": 2,
"wnli": 2,
"ax": 3,
}
loader_columns = [
"datasets_idx",
"input_ids",
"token_type_ids",
"attention_mask",
"start_positions",
"end_positions",
"labels",
]
def __init__(
self,
model_name_or_path: str,
task_name: str = "mrpc",
max_seq_length: int = 128,
train_batch_size: int = 32,
eval_batch_size: int = 32,
**kwargs,
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.task_name = task_name
self.max_seq_length = max_seq_length
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.text_fields = self.task_text_field_map[task_name]
self.num_labels = self.glue_task_num_labels[task_name]
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
def setup(self, stage: str):
self.dataset = datasets.load_dataset("glue", self.task_name)
for split in self.dataset.keys():
self.dataset[split] = self.dataset[split].map(
self.convert_to_features,
batched=True,
remove_columns=["label"],
)
self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
self.dataset[split].set_format(type="torch", columns=self.columns)
self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
def prepare_data(self):
datasets.load_dataset("glue", self.task_name)
AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
def train_dataloader(self):
return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)
def val_dataloader(self):
if len(self.eval_splits) == 1:
return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]
def test_dataloader(self):
if len(self.eval_splits) == 1:
return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]
def convert_to_features(self, example_batch, indices=None):
# Either encode single sentence or sentence pairs
if len(self.text_fields) > 1:
texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
else:
texts_or_text_pairs = example_batch[self.text_fields[0]]
# Tokenize the text/text pairs
features = self.tokenizer.batch_encode_plus(
texts_or_text_pairs, max_length=self.max_seq_length, pad_to_max_length=True, truncation=True
)
# Rename label to labels to make it easier to pass to model forward
features["labels"] = example_batch["label"]
return features
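To see what convert_to_features produces, you can run the same tokenizer call on a single sentence pair by hand (illustrative only; the exact keys depend on the model, e.g. DistilBERT tokenizers do not emit token_type_ids):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
encoded = tokenizer.batch_encode_plus(
    [("He said the food was good.", "The food was praised by him.")],  # one sentence pair
    max_length=128, pad_to_max_length=True, truncation=True,
)
print(list(encoded.keys()))  # e.g. ['input_ids', 'attention_mask']
print(len(encoded["input_ids"][0]))  # padded to max_length: 128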
Prepare the data using the DataModule
dm = GLUEDataModule("distilbert-base-uncased")
dm.prepare_data()
dm.setup("fit")
next(iter(dm.train_dataloader()))
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-419422735d445439.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-66e685e20329ff99.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-45b22f8d8d64c8ae.arrow
{'input_ids': tensor([[ 101, 2572, 3217, ..., 0, 0, 0],
[ 101, 9805, 3540, ..., 0, 0, 0],
[ 101, 2027, 2018, ..., 0, 0, 0],
...,
[ 101, 1996, 2922, ..., 0, 0, 0],
[ 101, 6202, 1999, ..., 0, 0, 0],
[ 101, 16565, 2566, ..., 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]]),
'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1,
1, 1, 0, 0, 1, 1, 1, 0])}
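A quick way to confirm the batch layout (not shown in the original output) is to print the tensor shapes; with the defaults above you should see 32 sequences padded to length 128:
batch = next(iter(dm.train_dataloader()))
print({k: tuple(v.shape) for k, v in batch.items()})
# e.g. {'input_ids': (32, 128), 'attention_mask': (32, 128), 'labels': (32,)}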
Transformer LightningModule
class GLUETransformer(LightningModule):
def __init__(
self,
model_name_or_path: str,
num_labels: int,
task_name: str,
learning_rate: float = 2e-5,
adam_epsilon: float = 1e-8,
warmup_steps: int = 0,
weight_decay: float = 0.0,
train_batch_size: int = 32,
eval_batch_size: int = 32,
eval_splits: Optional[list] = None,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
self.metric = datasets.load_metric(
"glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
)
def forward(self, **inputs):
return self.model(**inputs)
def training_step(self, batch, batch_idx):
outputs = self(**batch)
loss = outputs[0]
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(**batch)
val_loss, logits = outputs[:2]
if self.hparams.num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif self.hparams.num_labels == 1:
preds = logits.squeeze()
labels = batch["labels"]
return {"loss": val_loss, "preds": preds, "labels": labels}
def validation_epoch_end(self, outputs):
if self.hparams.task_name == "mnli":
for i, output in enumerate(outputs):
# matched or mismatched
split = self.hparams.eval_splits[i].split("_")[-1]
preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
loss = torch.stack([x["loss"] for x in output]).mean()
self.log(f"val_loss_{split}", loss, prog_bar=True)
split_metrics = {
f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
}
self.log_dict(split_metrics, prog_bar=True)
return loss
preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
loss = torch.stack([x["loss"] for x in outputs]).mean()
self.log("val_loss", loss, prog_bar=True)
self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
return loss
def setup(self, stage=None) -> None:
if stage != "fit":
return
# Get dataloader by calling it - train_dataloader() is called after setup() by default
train_loader = self.trainer.datamodule.train_dataloader()
# Calculate total steps
tb_size = self.hparams.train_batch_size * max(1, self.trainer.hpus)
ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size
def configure_optimizers(self):
"""Prepare optimizer and schedule (linear warmup and decay)"""
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.hparams.warmup_steps,
num_training_steps=self.total_steps,
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return [optimizer], [scheduler]
Training
CoLA
See an interactive view of the CoLA dataset in NLP Viewer
seed_everything(42)
dm = GLUEDataModule(model_name_or_path="albert-base-v2", task_name="cola")
dm.setup("fit")
model = GLUETransformer(
model_name_or_path="albert-base-v2",
num_labels=dm.num_labels,
eval_splits=dm.eval_splits,
task_name=dm.task_name,
)
trainer = Trainer(max_epochs=1, hpus=1)
trainer.fit(model, datamodule=dm)
MNLI
The MNLI dataset is huge, so we are not going to train on it here.
We will skip training and go straight to validation.
See an interactive view of the MNLI dataset in NLP Viewer
dm = GLUEDataModule(
model_name_or_path="distilbert-base-cased",
task_name="mnli",
)
dm.setup("fit")
model = GLUETransformer(
model_name_or_path="distilbert-base-cased",
num_labels=dm.num_labels,
eval_splits=dm.eval_splits,
task_name=dm.task_name,
)
trainer = Trainer(hpus=1, progress_bar_refresh_rate=20)
trainer.validate(model, dm.val_dataloader())
trainer.logged_metrics
Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.
License
Licensed under a CC BY SA 4.0 license.
A derivative of Introduction To PyTorch Lightning by PL Team