DistilBert Sequence Classification with IMDb Reviews

An adaptation of Huggingface Sequence Classification with IMDB Reviews using Habana Gaudi AI processors.

Overview

This tutorial will take you through one example of using Huggingface Transformers models with the IMDb dataset. The guide shows the workflow for training the model on Gaudi and is meant to be illustrative rather than definitive.

Note: The dataset can be explored on the Hugging Face Hub (imdb), and can alternatively be downloaded with the Huggingface Datasets library using load_dataset("imdb").
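
For reference, here is a minimal sketch of that alternative route. It relies on the datasets package installed later in this tutorial and is not used by the remaining steps:

from datasets import load_dataset

# Download the IMDb dataset directly from the Hugging Face Hub.
imdb = load_dataset("imdb")            # splits: "train", "test", "unsupervised"
print(imdb["train"][0]["text"][:200])  # first 200 characters of the first review
print(imdb["train"][0]["label"])       # 0 = negative, 1 = positive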

Setup

Let’s start by downloading the dataset from the Large Movie Review Dataset webpage.

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
--2022-11-17 06:50:07--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’

aclImdb_v1.tar.gz   100%[===================>]  80.23M  47.8MB/s    in 1.7s    

2022-11-17 06:50:08 (47.8 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]
!tar -xf aclImdb_v1.tar.gz

Install required libraries

We will install the Habana version of transformers inside the Docker container.

!git clone --depth=1 https://github.com/HabanaAI/Model-References.git
Cloning into 'Model-References'...
remote: Enumerating objects: 6099, done.
remote: Counting objects: 100% (6099/6099), done.
remote: Compressing objects: 100% (4350/4350), done.
remote: Total 6099 (delta 1809), reused 5230 (delta 1582), pack-reused 0
Receiving objects: 100% (6099/6099), 39.51 MiB | 19.45 MiB/s, done.
Resolving deltas: 100% (1809/1809), done.
pip install Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers/.
Processing ./Model-References/PyTorch/nlp/finetuning/huggingface/bert/transformers
  Preparing metadata (setup.py) ... done
Collecting filelock
  Downloading filelock-3.8.0-py3-none-any.whl (10 kB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 182.1/182.1 kB 8.2 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (1.22.3)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (21.3)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (5.4.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (2020.10.28)
Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (2.28.1)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 113.5 MB/s eta 0:00:00 0:00:01
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers==4.20.1) (4.64.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.20.1) (4.4.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->transformers==4.20.1) (3.0.9)
Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.20.1) (2.1.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.20.1) (1.26.12)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.20.1) (2022.9.24)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==4.20.1) (3.4)
Building wheels for collected packages: transformers
  Building wheel for transformers (setup.py) ... done
  Created wheel for transformers: filename=transformers-4.20.1-py3-none-any.whl size=4248789 sha256=1a9f15e173edb71e1a91a02dad9735f6c2bb2079087b9f4424b9dc6a8e57adf1
  Stored in directory: /tmp/pip-ephem-wheel-cache-olf_mzup/wheels/99/23/5e/6fb06b86f9a4787183569a5ba0fe05827108d9cc054a944f9a
Successfully built transformers
Installing collected packages: tokenizers, filelock, huggingface-hub, transformers
Successfully installed filelock-3.8.0 huggingface-hub-0.11.0 tokenizers-0.12.1 transformers-4.20.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip available: 22.3 -> 22.3.1
[notice] To update, run: python3 -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
pip install scikit-learn
Collecting scikit-learn
  Downloading scikit_learn-1.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 31.2/31.2 MB 279.4 MB/s eta 0:00:00a 0:00:01
Collecting scipy>=1.3.2
  Downloading scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (33.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 33.8/33.8 MB 333.0 MB/s eta 0:00:0000:0100:01
Collecting threadpoolctl>=2.0.0
  Downloading threadpoolctl-3.1.0-py3-none-any.whl (14 kB)
Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.22.3)
Collecting joblib>=1.0.0
  Downloading joblib-1.2.0-py3-none-any.whl (297 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 298.0/298.0 kB 361.0 MB/s eta 0:00:00
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.2.0 scikit-learn-1.1.3 scipy-1.9.3 threadpoolctl-3.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip available: 22.3 -> 22.3.1
[notice] To update, run: python3 -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
pip install datasets
Collecting datasets
  Downloading datasets-2.7.0-py3-none-any.whl (451 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 451.6/451.6 kB 11.4 MB/s eta 0:00:00a 0:00:01
Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)
Collecting xxhash
  Downloading xxhash-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 212.9/212.9 kB 338.2 MB/s eta 0:00:00
Collecting dill<0.3.7
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 110.5/110.5 kB 314.3 MB/s eta 0:00:00
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.28.1)
Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.4.1)
Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2022.10.0)
Collecting pyarrow>=6.0.0
  Downloading pyarrow-10.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.3/35.3 MB 221.8 MB/s eta 0:00:0000:0100:01
Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.11.0)
Collecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 132.0/132.0 kB 319.8 MB/s eta 0:00:00
Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (5.4.1)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.22.3)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (21.3)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.1)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.2.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.1.1)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.2)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.1.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.8.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging->datasets) (3.0.9)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.9.24)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.12)
Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.6)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)
Installing collected packages: xxhash, pyarrow, dill, responses, multiprocess, datasets
Successfully installed datasets-2.7.0 dill-0.3.6 multiprocess-0.70.14 pyarrow-10.0.0 responses-0.18.0 xxhash-3.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip available: 22.3 -> 22.3.1
[notice] To update, run: python3 -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.

This data is organized into pos and neg folders with one text file per example.

import os
from pathlib import Path

def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir == "neg" else 1)

    return texts, labels

train_texts, train_labels = read_imdb_split('./aclImdb/train')
test_texts, test_labels = read_imdb_split('./aclImdb/test')

We now have a train and test dataset, but let’s also create a validation set which we can use for evaluation and tuning without touching our test set. Scikit-learn has a convenient utility for creating such splits:

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

Alright, we’ve read in our dataset. Now let’s tackle tokenization. We’ll eventually train a classifier using pre-trained DistilBert, so let’s use the DistilBert tokenizer.

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 47.6kB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 226k/226k [00:00<00:00, 770kB/s]
Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 455k/455k [00:00<00:00, 1.90MB/s]
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 483/483 [00:00<00:00, 398kB/s]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'DistilBertTokenizerFast'.

Now we can simply pass our texts to the tokenizer. We’ll pass truncation=True and padding=True, which will ensure that all of our sequences are padded to the same length and are truncated to be no longer than the model’s maximum input length. This will allow us to feed batches of sequences into the model at the same time.

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

Building the model and Fine-tuning with Trainer on Gaudi

Now, let’s turn our labels and encodings into a Dataset object. In PyTorch, this is done by subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__. We put the data in this format so that it can be easily batched, with each key in the batch encoding corresponding to a named parameter of the forward() method of the model we will train.

import torch

class IMDbDataset(torch.utils.data.Dataset):
    """Wraps the tokenizer encodings and labels as a PyTorch Dataset."""
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # One example as a dict of tensors keyed by the model's input names.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)
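
As a quick sanity check (a hedged sketch, not part of the original notebook output), each item now carries exactly the keys that DistilBertForSequenceClassification.forward() accepts:

# Inspect one training example; its keys line up with the model's
# forward() parameters: input_ids, attention_mask, and labels.
sample = train_dataset[0]
print(sample.keys())              # dict_keys(['input_ids', 'attention_mask', 'labels'])
print(sample['input_ids'].shape)  # typically torch.Size([512]) for IMDb after truncation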

The steps above prepared the datasets in the format the Trainer expects. Now all we need to do is create a model to fine-tune, define the TrainingArguments, and instantiate a Trainer. Let’s enable training on Gaudi by setting the following variables in TrainingArguments:

  • The argument use_hpu sets Gaudi as the default device;
  • The argument hmp enables mixed precision;
  • The hmp_bf16 and hmp_fp32 files specify the BF16 op list and the FP32 op list, respectively;
  • hmp_verbose controls the printout of datatype conversions between BF16 and FP32.

In this example, we set hmp_verbose=False for a clean output.

from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    use_hpu=True,
    use_lazy_mode=True,
    use_fused_adam=True,
    use_fused_clip_norm=True,
    hmp=True,
    hmp_bf16='./ops_bf16_distilbert_pt.txt',
    hmp_fp32='./ops_fp32_distilbert_pt.txt',
    hmp_verbose=False,
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)
/usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/utils/distributed_utils.py:6: UserWarning: habana_frameworks.torch.utils.distributed_utils.initialize_distributed_hpu is deprecated. Please use habana_frameworks.torch.distributed.hccl.initialize_distributed_hpu
  warnings.warn("habana_frameworks.torch.utils.distributed_utils.initialize_distributed_hpu is deprecated. "
--------------------------------------------------------------------------
An invalid value was supplied for an enum variable.

  Variable     : btl_vader_single_copy_mechanism
  Value        : non
  Valid values : 1:"cma", 4:"emulated", 3:"none"
--------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Now we can train the model, either from a previously saved checkpoint or from the pretrained model. The full training defaults to 3 epochs.

if not os.path.isdir("./results/checkpoint-3500"):
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    trainer = Trainer(
        model=model,                         # the instantiated Transformers model to be trained
        args=training_args,                  # training arguments, defined above
        train_dataset=train_dataset,         # training dataset
        eval_dataset=val_dataset             # evaluation dataset
    )
    trainer.train()
else:
    model = DistilBertForSequenceClassification.from_pretrained("./results/checkpoint-3500")
loading configuration file ./results/checkpoint-3500/config.json
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.20.1",
  "vocab_size": 30522
}

loading weights file ./results/checkpoint-3500/pytorch_model.bin
All model checkpoint weights were used when initializing DistilBertForSequenceClassification.

All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at ./results/checkpoint-3500.
If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.

After the training finishes, we can evaluate the results on the validation dataset. The function compute_metrics is used to calculate the accuracy.

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
/tmp/ipykernel_120/3142568293.py:3: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  metric = load_metric("accuracy")
Downloading builder script: 4.21kB [00:00, 4.50MB/s]                                                                                                                            
Enabled lazy mode
hmp:verbose_mode  F

Print out the final results

At the end of the training, we can print out the final training/evaluation result.

print("**************** Evaluation below************")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(val_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 64
**************** Evaluation below************
[79/79 00:17]
***** eval metrics *****
  eval_accuracy           =     0.9266
  eval_loss               =     0.3263
  eval_runtime            = 0:00:21.82
  eval_samples            =       5000
  eval_samples_per_second =    229.112
  eval_steps_per_second   =       3.62
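
The test split prepared earlier (test_dataset) is not scored in the run above. Here is a minimal sketch of how it could be evaluated with the same Trainer, reusing the fine-tuned model; the resulting numbers will differ from the validation metrics shown here:

# Evaluate the held-out test set with the existing Trainer.
test_metrics = trainer.evaluate(eval_dataset=test_dataset)
test_metrics["eval_samples"] = len(test_dataset)
trainer.log_metrics("test", test_metrics)
trainer.save_metrics("test", test_metrics)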

Gaudi training tips for the Trainer in Hugging Face Transformers

In TrainingArguments setup:

  • Set use_hpu=True to make Gaudi the default device.
  • Set use_lazy_mode=True to enable lazy mode for better performance.
  • Set use_fused_adam=True to use the Gaudi-customized Adam optimizer for better performance.
  • Set use_fused_clip_norm=True to use the Gaudi-customized clip_norm kernel for better performance.
  • Set hmp=True to enable mixed precision.
    • The default hmp_verbose value is True; setting hmp_verbose=False keeps the printout clean.
    • For mixed precision, the following flags are needed:
      • hmp_bf16='./ops_bf16_distilbert_pt.txt',
      • hmp_fp32='./ops_fp32_distilbert_pt.txt',

Summary

You can enable a model script on Gaudi by specifying a few Gaudi-specific arguments in TrainingArguments.

Copyright© 2021 Habana Labs, Ltd. an Intel Company.

Licensed under the Apache License, Version 2.0 (the “License”);

You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
