
DistilBert Sequence Classification with IMDb Reviews

An adaptation of the Hugging Face Sequence Classification with IMDb Reviews tutorial, run on Habana Gaudi AI processors.

Overview

This tutorial walks through one example of using Hugging Face Transformers models with the IMDb dataset. It shows the workflow for training the model on Gaudi and is meant to be illustrative rather than definitive.

Note: The dataset can be explored in the Hugging Face dataset hub (IMDb), and can alternatively be downloaded with the Hugging Face datasets library via load_dataset("imdb"), as sketched below.
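For reference, here is a minimal sketch of that alternative path using the datasets library; the rest of this tutorial continues with the manually downloaded archive instead.

from datasets import load_dataset

# Download and cache the IMDb dataset directly from the Hugging Face hub
imdb = load_dataset("imdb")
print(imdb)  # DatasetDict with 'train', 'test', and 'unsupervised' splits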

Setup

Let’s start by downloading the dataset from the Large Movie Review Dataset webpage.

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
--2022-03-02 06:42:46--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving proxy-us.intel.com (proxy-us.intel.com)... 10.1.192.48
Connecting to proxy-us.intel.com (proxy-us.intel.com)|10.1.192.48|:911... connected.
Proxy request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’

aclImdb_v1.tar.gz   100%[===================>]  80.23M  38.5MB/s    in 2.1s

2022-03-02 06:42:49 (38.5 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]
!tar -xf aclImdb_v1.tar.gz

Install required libraries

We will install the Habana version of transformers inside the Docker container.

!git clone --depth=1 https://github.com/HabanaAI/Model-References.git
Cloning into 'Model-References'... remote: Enumerating objects: 2956, done. remote: Counting objects: 100% (2956/2956), done. remote: Compressing objects: 100% (2132/2132), done. remote: Total 2956 (delta 832), reused 2323 (delta 745), pack-reused 0 Receiving objects: 100% (2956/2956), 19.04 MiB | 2.38 MiB/s, done. Resolving deltas: 100% (832/832), done. Checking out files: 100% (2668/2668), done.
pip install Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers/.
Processing ./Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers Collecting filelock Downloading https://files.pythonhosted.org/packages/cd/f1/ba7dee3de0e9d3b8634d6fbaa5d0d407a7da64620305d147298b683e5c36/filelock-3.6.0-py3-none-any.whl Collecting huggingface-hub<1.0,>=0.1.0 Downloading https://files.pythonhosted.org/packages/c8/df/1b454741459f6ce75f86534bdad42ca17291b14a83066695f7d2c676e16c/huggingface_hub-0.4.0-py3-none-any.whl (67kB) |████████████████████████████████| 71kB 1.8MB/s eta 0:00:01 Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (1.22.2) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (21.3) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (5.4.1) Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (2020.10.28) Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (2.27.1) Collecting sacremoses Downloading https://files.pythonhosted.org/packages/ec/e5/407e634cbd3b96a9ce6960874c5b66829592ead9ac762bd50662244ce20b/sacremoses-0.0.47-py2.py3-none-any.whl (895kB) |████████████████████████████████| 901kB 3.2MB/s eta 0:00:01 Collecting tokenizers<0.11,>=0.10.1 Downloading https://files.pythonhosted.org/packages/e4/bd/10c052faa46f4effb18651b66f01010872f8eddb5f4034d72c08818129bd/tokenizers-0.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB) |████████████████████████████████| 3.3MB 1.2MB/s eta 0:00:01 Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers==4.15.0) (4.62.3) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.15.0) (3.10.0.2) Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->transformers==4.15.0) (3.0.7) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (1.26.8) Requirement already satisfied: idna<4,>=2.5; python_version >= "3" in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (3.3) Requirement already satisfied: charset-normalizer~=2.0.0; python_version >= "3" in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (2.0.11) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.15.0) (2021.10.8) Collecting click Downloading https://files.pythonhosted.org/packages/4a/a8/0b2ced25639fb20cc1c9784de90a8c25f9504a7f18cd8b5397bd61696d7d/click-8.0.4-py3-none-any.whl (97kB) |████████████████████████████████| 102kB 1.3MB/s ta 0:00:011 Requirement already satisfied: six in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==4.15.0) (1.16.0) Collecting joblib Downloading https://files.pythonhosted.org/packages/3e/d5/0163eb0cfa0b673aa4fe1cd3ea9d8a81ea0f32e50807b0c295871e4aab2e/joblib-1.1.0-py2.py3-none-any.whl (306kB) |████████████████████████████████| 307kB 1.8MB/s eta 0:00:01 Building wheels for collected packages: transformers Building wheel for transformers (setup.py) ... 
done Created wheel for transformers: filename=transformers-4.15.0-cp38-none-any.whl size=3337816 sha256=3703d4cd928ecc7dc1276ca0d34e7a68631ed8714624e4817ec333aabd78f9fd Stored in directory: /tmp/pip-ephem-wheel-cache-8o1nzs3k/wheels/81/13/08/5690c044e7e3afad106a846035b89b57bafa2f585c07f7ed7a Successfully built transformers Installing collected packages: filelock, huggingface-hub, click, joblib, sacremoses, tokenizers, transformers Successfully installed click-8.0.4 filelock-3.6.0 huggingface-hub-0.4.0 joblib-1.1.0 sacremoses-0.0.47 tokenizers-0.10.3 transformers-4.15.0 WARNING: You are using pip version 19.3.1; however, version 22.0.3 is available. You should consider upgrading via the 'pip install --upgrade pip' command. Note: you may need to restart the kernel to use updated packages.
pip install sklearn
Collecting sklearn Downloading https://files.pythonhosted.org/packages/1e/7a/dbb3be0ce9bd5c8b7e3d87328e79063f8b263b2b1bfa4774cb1147bfcd3f/sklearn-0.0.tar.gz Collecting scikit-learn Downloading https://files.pythonhosted.org/packages/40/d3/206905d836cd496c1f78a15ef92a0f0477d74113b4f349342bf31dfd62ca/scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7MB) |████████████████████████████████| 26.7MB 1.0MB/s eta 0:00:01 |████████████████████████▊ | 20.6MB 1.6MB/s eta 0:00:04 Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn->sklearn) (1.1.0) Collecting scipy>=1.1.0 Downloading https://files.pythonhosted.org/packages/d2/27/b2648569175ba233cb6ad13029f8df4049a581c268156c5dd1db5ca44a8c/scipy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.6MB) |████████████████████████████████| 41.6MB 2.0MB/s eta 0:00:01 Collecting threadpoolctl>=2.0.0 Downloading https://files.pythonhosted.org/packages/61/cf/6e354304bcb9c6413c4e02a747b600061c21d38ba51e7e544ac7bc66aecc/threadpoolctl-3.1.0-py3-none-any.whl Requirement already satisfied: numpy>=1.14.6 in /usr/local/lib/python3.8/dist-packages (from scikit-learn->sklearn) (1.22.2) Building wheels for collected packages: sklearn Building wheel for sklearn (setup.py) ... done Created wheel for sklearn: filename=sklearn-0.0-py2.py3-none-any.whl size=1310 sha256=5c5befec7253ff2f4f1bf402b24ac3530a1ededf93d056d13c3e0fa0d86d9c83 Stored in directory: /tmp/pip-ephem-wheel-cache-babdam61/wheels/76/03/bb/589d421d27431bcd2c6da284d5f2286c8e3b2ea3cf1594c074 Successfully built sklearn Installing collected packages: scipy, threadpoolctl, scikit-learn, sklearn Successfully installed scikit-learn-1.0.2 scipy-1.8.0 sklearn-0.0 threadpoolctl-3.1.0 WARNING: You are using pip version 19.3.1; however, version 22.0.3 is available. You should consider upgrading via the 'pip install --upgrade pip' command. Note: you may need to restart the kernel to use updated packages.
pip install datasets
Collecting datasets Downloading https://files.pythonhosted.org/packages/a6/45/ecbd6d5d6385b9702f8bb53801c66379edf044b373bbb77f184289cd3811/datasets-1.18.3-py3-none-any.whl (311kB) |████████████████████████████████| 317kB 3.1MB/s eta 0:00:01 Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (21.3) Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.3) Collecting xxhash Downloading https://files.pythonhosted.org/packages/6a/cf/50f4cfde85d90c2b3e9c98b46e17d190bbdd97b54d3e0876e1d9360e487f/xxhash-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212kB) |████████████████████████████████| 215kB 3.5MB/s eta 0:00:01 Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2022.1.0) Collecting pyarrow!=4.0.0,>=3.0.0 Downloading https://files.pythonhosted.org/packages/98/7d/fb38132dd606533b36a3fde8b17db95a36351dc58afbc6dc6b3d668ef3f0/pyarrow-7.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7MB) |████████████████████████████████| 26.7MB 9.5MB/s eta 0:00:01 Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.62.3) Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.1) Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.4.0) Collecting multiprocess Downloading https://files.pythonhosted.org/packages/e6/22/b09b8394f8c86ff0cfebd725ea96bba0accd4a4b2be437bcba6a0cf7d1c3/multiprocess-0.70.12.2-py38-none-any.whl (128kB) |████████████████████████████████| 133kB 3.1MB/s eta 0:00:01 Collecting dill Downloading https://files.pythonhosted.org/packages/b6/c3/973676ceb86b60835bb3978c6db67a5dc06be6cfdbd14ef0f5a13e3fc9fd/dill-0.3.4-py2.py3-none-any.whl (86kB) |████████████████████████████████| 92kB 6.5MB/s eta 0:00:011 Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.27.1) Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.22.2) Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging->datasets) (3.0.7) Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2021.3) Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2) Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.0.11) Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2) Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.7.2) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (21.4.0) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.2) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.0) Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.2.0) Requirement already satisfied: typing-extensions>=3.7.4.3 in 
/usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.10.0.2) Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (5.4.1) Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.6.0) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.8) Requirement already satisfied: idna<4,>=2.5; python_version >= "3" in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (3.3) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2021.10.8) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0) Installing collected packages: xxhash, pyarrow, dill, multiprocess, datasets Successfully installed datasets-1.18.3 dill-0.3.4 multiprocess-0.70.12.2 pyarrow-7.0.0 xxhash-3.0.0 WARNING: You are using pip version 19.3.1; however, version 22.0.3 is available. You should consider upgrading via the 'pip install --upgrade pip' command. Note: you may need to restart the kernel to use updated packages.

This data is organized into pos and neg folders with one text file per example.

import os
from pathlib import Path

def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir == "neg" else 1)

    return texts, labels

train_texts, train_labels = read_imdb_split('./aclImdb/train')
test_texts, test_labels = read_imdb_split('./aclImdb/test')
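As a quick, optional sanity check (not part of the original notebook), you can confirm that both splits loaded fully; the IMDb train and test splits each contain 25,000 labeled reviews:

# Each split should contain 25,000 labeled reviews (12,500 pos + 12,500 neg)
print(len(train_texts), len(test_texts))   # expected: 25000 25000
print(train_labels[:3])                    # a few example labels (1 = pos, 0 = neg)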

We now have a train and test dataset, but let's also create a validation set which we can use for evaluation and tuning without tainting our test set results. Scikit-learn has a convenient utility for creating such splits:

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

Alright, we’ve read in our dataset. Now let’s tackle tokenization. We’ll eventually train a classifier using pre-trained DistilBert, so let’s use the DistilBert tokenizer.

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

Now we can simply pass our texts to the tokenizer. We’ll pass truncation=True and padding=True, which will ensure that all of our sequences are padded to the same length and are truncated to be no longer than the model’s maximum input length. This will allow us to feed batches of sequences into the model at the same time.

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

Building the model and Fine-tuning with Trainer on Gaudi

Now, let's turn our labels and encodings into a Dataset object. In PyTorch, this is done by subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__. We put the data in this format so that it can be easily batched, with each key in the batch encoding corresponding to a named parameter of the forward() method of the model we will train.

import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)
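Optionally (not part of the original notebook), you can inspect a single item to see that its keys line up with the named parameters of the model's forward() method, as described above:

# Optional check: each item is a dict of tensors whose keys match the model's forward() parameters
sample = train_dataset[0]
print(sample.keys())              # expected: dict_keys(['input_ids', 'attention_mask', 'labels'])
print(sample['input_ids'].shape)  # padded/truncated sequence length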

The steps above prepared the datasets in the format the Trainer expects. Now all we need to do is create a model to fine-tune, define the TrainingArguments, and instantiate a Trainer. Let's enable training on Gaudi by setting the following variables in TrainingArguments:

  • The argument use_habana sets Gaudi as the default device.
  • The argument hmp enables mixed precision.
  • The argument hmp_opt_level defines the optimization level; it accepts two values, 'O1' and 'O2', and defaults to 'O1'. With hmp_opt_level='O1', hmp_bf16 and hmp_fp32 are required to specify the BF16 op list and the FP32 op list.
  • The argument hmp_verbose controls the printout of datatype conversions between BF16 and FP32. In this example, we set the mixed precision optimization level hmp_opt_level='O1' and hmp_verbose=False for a clean output.
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    use_habana=True,
    use_lazy_mode=True,
    use_fused_adam=True,
    use_fused_clip_norm=True,
    hmp=True,
    hmp_bf16='./ops_bf16_distilbert_pt.txt',
    hmp_fp32='./ops_fp32_distilbert_pt.txt',
    hmp_verbose=False,
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)
Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib synapse_logger INFO. pid=14 at /home/jenkins/workspace/cdsoftwarebuilder/create-pytorch---bpt-d/repos/pytorch-integration/pytorch_helpers/synapse_logger/synapse_logger.cpp:340 Done command: restart huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks... To disable this warning, you can either: - Avoid using `tokenizers` before the fork if possible - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Now we can train the model, either from a previously saved checkpoint or from the pretrained model. By default, the full training runs for 3 epochs.

if not os.path.isdir("./results/checkpoint-3500"):
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    trainer = Trainer(
        model=model,                  # the instantiated Transformers model to be trained
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=val_dataset      # evaluation dataset
    )
    trainer.train()
else:
    model = DistilBertForSequenceClassification.from_pretrained("./results/checkpoint-3500")
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. Enabled lazy mode hmp:verbose_mode False hmp:opt_level O1 ***** Running training ***** Num examples = 20000 Num Epochs = 3 Instantaneous batch size per device = 16 Total train batch size (w. parallel, distributed & accumulation) = 16 Gradient Accumulation steps = 1 Total optimization steps = 3750
[3750/3750 08:06, Epoch 3/3]
Step    Training Loss
10      0.687500
20      0.681300
30      0.687500
40      0.687500
50      0.684400
…       …
3730    0.006200
3740    0.019200
3750    0.018700
(loss logged every 10 steps; intermediate rows omitted)
Saving model checkpoint to ./results/checkpoint-500 Configuration saved in ./results/checkpoint-500/config.json Model weights saved in ./results/checkpoint-500/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-1000 Configuration saved in ./results/checkpoint-1000/config.json Model weights saved in ./results/checkpoint-1000/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-1500 Configuration saved in ./results/checkpoint-1500/config.json Model weights saved in ./results/checkpoint-1500/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-2000 Configuration saved in ./results/checkpoint-2000/config.json Model weights saved in ./results/checkpoint-2000/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-2500 Configuration saved in ./results/checkpoint-2500/config.json Model weights saved in ./results/checkpoint-2500/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-3000 Configuration saved in ./results/checkpoint-3000/config.json Model weights saved in ./results/checkpoint-3000/pytorch_model.bin Saving model checkpoint to ./results/checkpoint-3500 Configuration saved in ./results/checkpoint-3500/config.json Model weights saved in ./results/checkpoint-3500/pytorch_model.bin Training completed. Do not forget to share your model on huggingface.co/models =)

After training finishes, we can evaluate the results using the validation dataset. The function compute_metrics is used to calculate the accuracy.

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
Enabled lazy mode hmp:verbose_mode False hmp:opt_level O1

Print out the final results

At the end of the training, we can print out the final training/evaluation result.

print("**************** Evaluation below************")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(val_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 64
**************** Evaluation below************
[79/79 00:19]
***** eval metrics *****
  eval_accuracy           =     0.9248
  eval_loss               =     0.2556
  eval_runtime            = 0:00:26.47
  eval_samples            =       5000
  eval_samples_per_second =    188.837
  eval_steps_per_second   =      2.984
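Optionally, the held-out test split prepared earlier can be scored the same way. This extra step is not part of the original run, but it uses only the Trainer API already shown:

# Optional: score the held-out IMDb test split with the same Trainer
test_metrics = trainer.evaluate(eval_dataset=test_dataset)
print(test_metrics.get("eval_accuracy"))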

Gaudi training tips for the Hugging Face Transformers Trainer

In the TrainingArguments setup:

  • Set use_habana=True to enable the Gaudi device.
  • Set use_lazy_mode=True to enable lazy mode for better performance.
  • Set use_fused_adam=True to use the Gaudi-customized Adam optimizer for better performance.
  • Set use_fused_clip_norm=True to use the Gaudi-customized clip_norm kernel for better performance.
  • Set hmp=True to enable mixed precision (see the consolidated sketch after this list).
    • The default hmp_verbose value is True; setting hmp_verbose=False keeps the printout clean.
    • The default hmp_opt_level value is 'O1'.
    • For hmp_opt_level='O1', the following flags are needed:
      • hmp_bf16='./ops_bf16_distilbert_pt.txt'
      • hmp_fp32='./ops_fp32_distilbert_pt.txt'
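For convenience, here is a minimal sketch consolidating the Gaudi-specific flags above into a single TrainingArguments call; it assumes the Habana fork of transformers installed earlier and the two op-list files in the working directory:

from transformers import TrainingArguments

# Gaudi-specific switches only; all other arguments keep their defaults
gaudi_args = TrainingArguments(
    output_dir='./results',
    use_habana=True,                          # run on the Gaudi device
    use_lazy_mode=True,                       # lazy mode for better performance
    use_fused_adam=True,                      # Gaudi-customized Adam optimizer
    use_fused_clip_norm=True,                 # Gaudi-customized clip_norm kernel
    hmp=True,                                 # enable mixed precision
    hmp_bf16='./ops_bf16_distilbert_pt.txt',  # BF16 op list (required for O1)
    hmp_fp32='./ops_fp32_distilbert_pt.txt',  # FP32 op list (required for O1)
    hmp_verbose=False,                        # keep the printout clean
)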

Summary

You can easily enable your model script on Gaudi by specifying a few Gaudi-specific arguments in TrainingArguments.

Copyright© 2021 Habana Labs, Ltd. an Intel Company.

Licensed under the Apache License, Version 2.0 (the “License”);

You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
