
Inference on Gaudi solutions with HPU Graph

This tutorial demonstrates how to run inference on Gaudi solutions using HPU Graphs for better performance.

Overview

This tutorial will show you how to run inference on first-gen Gaudi and Gaudi2 accelerators. These are simple, fully runnable examples that show how to run inference using the MNIST dataset, a simple checkpoint, and a linear model. For more details on inference on Gaudi and Gaudi2, refer to the Inference User Guide. There are three examples on GitHub:

  • Example 1 is a simple inference example showing how to run using the model.eval() path, which is the most direct path to running inference.
  • Example 2 adds the use of HPUGraph with the Graph and Stream APIs.
  • Example 3 uses HPU Graphs with the wrap_in_hpu_graph API, a simpler interface to the Graph and Stream APIs.

The HPUGraph API provides a performance optimization technique to reduce PyTorch host overhead. This is done by capturing the PyTorch execution on a stream for the first iteration and replaying that in subsequent ones. The replay avoids the PyTorch overhead of accumulating the ops in the model and makes the execution device bound.
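
As an illustration of this capture-and-replay flow, below is a minimal sketch using the explicit HPUGraph and Stream APIs (the approach taken in Example 2). The model and tensor names here are placeholders rather than part of the GitHub examples; the exact API surface is described in the reference documentation.

import torch
import habana_frameworks.torch as ht

model = torch.nn.Linear(784, 10).eval().to("hpu")   # placeholder model
static_input = torch.randn(32, 784, device="hpu")   # reusable input buffer

graph = ht.hpu.HPUGraph()
stream = ht.hpu.Stream()

# First iteration: capture the execution on a stream
with ht.hpu.stream(stream):
    graph.capture_begin()
    static_output = model(static_input)
    graph.capture_end()

# Subsequent iterations: copy new data into the captured input buffer
# and replay the graph, avoiding per-op host overhead
static_input.copy_(torch.randn(32, 784, device="hpu"))
graph.replay()
print(static_output.shape)

Example 3 hides these capture and replay mechanics behind the wrap_in_hpu_graph wrapper, as shown later in this tutorial.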

For further details on the Stream and HPU Graph APIs, refer to the HPU Graph APIs and Stream APIs sections of the reference documentation.

The HPU Graph API from Habana can deliver performance gains in real-world models where the application is latency sensitive, or where the host time exceeds the device time due to a small batch size. The HPU Graphs feature helps minimize this host time.
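
As a rough way to check whether a workload is host-bound in this sense, the hypothetical sketch below times the same model with and without wrap_in_hpu_graph at a small batch size; the model, batch size, and iteration count are placeholders, not part of the GitHub examples.

import time
import torch
import habana_frameworks.torch as ht

def avg_latency(model, batch, iters=100):
    model(batch)            # warm-up iteration (captures the graph when wrapped)
    ht.hpu.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(batch)
    ht.hpu.synchronize()
    return (time.perf_counter() - start) / iters

model = torch.nn.Linear(784, 10).eval().to("hpu")   # placeholder model
batch = torch.randn(4, 784, device="hpu")           # small batch: host-bound case

eager = avg_latency(model, batch)
graphed = avg_latency(ht.hpu.wrap_in_hpu_graph(model), batch)
print(f"eager: {eager * 1e3:.3f} ms/iter, hpu graph: {graphed * 1e3:.3f} ms/iter")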

The three examples are provided in Jupyter notebooks. To run them, you will need access to a Gaudi system:

You can follow the Installation Guide to pull and run a Habana PyTorch Docker image and then install the JupyterLab library. The walkthrough below follows Example 3, using inference with HPU Graphs.

Inference on Gaudi – Example 3

This example shows inference mode with HPU Graphs using the built-in wrapper `wrap_in_hpu_graph`, with a simple model and the MNIST dataset.

Download the pretrained model checkpoint from the Habana Vault.

!wget https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth

Import all necessary dependencies.

import os
import sys
import torch
import time
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

Define a simple Net model for MNIST.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1   = nn.Linear(784, 256)
        self.fc2   = nn.Linear(256, 64)
        self.fc3   = nn.Linear(64, 10)
    def forward(self, x):
        out = x.view(-1,28*28)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        out = F.log_softmax(out, dim=1)
        return out

Create the model and load the pre-trained checkpoint, then set the model to evaluation mode. The model is moved to the Gaudi accelerator (“hpu”) in the next step.

model = Net()
checkpoint = torch.load('mnist-epoch_20.pth')
model.load_state_dict(checkpoint)
model = model.eval()

Wrap the model with HPU Graphs and move it to the HPU. Here we use `wrap_in_hpu_graph` to wrap the module's forward function with HPU Graphs; this wrapper captures, caches, and replays the graph.

model = ht.hpu.wrap_in_hpu_graph(model)
model = model.to("hpu")

Create an MNIST dataset for evaluation. This is pulled from the torchvision library.

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

data_path = './data'
test_kwargs = {'batch_size': 32}
dataset1 = datasets.MNIST(data_path, train=False, download=True, transform=transform)
test_loader = DataLoader(dataset1, **test_kwargs)

Do a warm-up run: here the HPU Graph will be captured and cached.

warmup_input = torch.randn(32, 1, 28, 28, device='hpu')
warmup_output = model(warmup_input)

Run inference.

Since the model is already wrapped with HPU Graphs via `wrap_in_hpu_graph` as shown above, there is no need to copy and replay the stream explicitly; this is all handled in the background. We also use asynchronous copies, as shown below (a copy with `non_blocking=True` followed by `mark_step`), to further optimize inference; refer to the documentation for more information on asynchronous copies. Adding `mark_step` after `model()` is not required with HPU Graphs, as it is handled implicitly.

correct = 0
for batch_idx, (data, label) in enumerate(test_loader):
    # Asynchronous copy to the HPU; mark_step launches the queued copy
    data = data.to("hpu", non_blocking=True)
    htcore.mark_step()
    # Forward pass replays the cached HPU Graph; no extra mark_step needed
    output = model(data)
    # Accumulate the number of correct top-1 predictions
    correct += output.max(1)[1].eq(label).sum()

print('Accuracy: {:.2f}%'.format(100. * correct / (len(test_loader) * 32)))
Accuracy: 94.36%

Summary

Running inference on Gaudi is easy to do, and adding HPU Graphs can improve the performance of your model.

Copyright© 2023 Habana Labs, Ltd. an Intel Company.

Licensed under the Apache License, Version 2.0 (the “License”);

You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

