Dynamic Shapes and how to detect them
Dynamicity usually introduces recompilations, which slow down execution. To optimize a model’s speed, it is desirable to identify whether it has dynamic inputs or ops and then mitigate them where possible by following the steps shown in this document. In this notebook we discuss some tools to detect dynamic inputs and ops.
Types of Dynamicity
Before we start looking at optimizations, we should discuss the main places that generate Dynamic Shapes. Dynamic Shapes can be broadly classified into two categories:
- Inputs – Dynamic shapes due to varying input shapes during training, such as varying sentence lengths in language models or differing image resolutions in image models.
- Ops – Dynamic shapes due to Ops whose output shape depends on the actual input data rather than only on the input shapes, that is, Ops whose output shapes cannot be inferred from the input shapes alone.
Follow these two steps to look for Dynamicity in your model
Step 1: Check for general recompilations and use Habana’s Dynamic Shape automated support feature.
- Set the environment flags as follows:
  PT_HPU_METRICS_FILE=/root/metricslog.json PT_HPU_METRICS_DUMP_TRIGGERS=process_exit,metric_change
  This will give a broad sense of the recompilations in the model. The metricslog.json file created will show how often a graph_compilation is called. For static graphs, a reduction in recompilations is expected after a few steps.
- If recompilations continue to exist, set PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES=1 to enable Habana’s automated Dynamic Shape control. This variable can be set to enable the Habana PyTorch bridge and Graph Compiler to automatically manage dynamic shapes in model scripts. The graphs will be automatically bucketed and padded into ranges to achieve a common size, reducing recompilations and improving performance when working with dynamic workloads.
- If recompilations still continue to exist, or you encounter instability and want to achieve better performance, go to Step 2.
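The flags from Step 1 can also be assembled programmatically before launching a training script. The following is a minimal sketch: the helper name `hpu_recompile_env` and its parameters are hypothetical, and only the `PT_HPU_*` variable names and values are taken from the text above.

```python
import os


def hpu_recompile_env(metrics_file, refine_dynamic_shapes=False, base_env=None):
    """Return an environment dict with the Step 1 flags set.

    Hypothetical helper: only the PT_HPU_* names and values come from the
    tutorial text; everything else is for illustration.
    """
    env = dict(os.environ if base_env is None else base_env)
    # Where the metrics log is written and when it is dumped
    env["PT_HPU_METRICS_FILE"] = metrics_file
    env["PT_HPU_METRICS_DUMP_TRIGGERS"] = "process_exit,metric_change"
    # Optionally turn on the automated dynamic shape handling
    if refine_dynamic_shapes:
        env["PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES"] = "1"
    return env


env = hpu_recompile_env("/root/metricslog.json", refine_dynamic_shapes=True, base_env={})
for name in sorted(env):
    print(f"{name}={env[name]}")
```

The returned dictionary can be passed as the `env` argument of `subprocess.run` when starting the training script, so that the first run collects metrics only and a second run adds the automated dynamic shape handling.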
Step 2: Deeper Analysis of the model’s Data and Ops
The rest of this tutorial will cover the details of how to use these tools for specific analysis of your model. These tools will allow you to pinpoint areas of dynamicity and make improvements.
Detecting Dynamic inputs with the Data Dynamicity Tool
In this section we will use the data_dynamicity tool, which accepts a torch dataloader and produces a report of how many distinct input shapes it sees, to look at low vs high dynamicity in input datasets. We will also discuss some strategies to mitigate high input dynamicity by padding.
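The idea behind such a report can be sketched in a few lines of plain Python. This is an illustration of shape counting only, not the tool's actual implementation; `count_batch_shapes` and `FakeTensor` are hypothetical names, and the stand-in object only mimics the `.shape` attribute that real tensors expose.

```python
from collections import Counter, namedtuple


def count_batch_shapes(batches):
    """Count how often each combination of per-item shapes occurs."""
    counts = Counter()
    for batch in batches:
        # A batch is a tuple of array-like items, e.g. (images, labels)
        key = tuple(tuple(item.shape) for item in batch)
        counts[key] += 1
    return counts


# Stand-in for a tensor: anything with a .shape attribute works here
FakeTensor = namedtuple("FakeTensor", ["shape"])

batches = [
    (FakeTensor((7, 1, 28, 28)), FakeTensor((7,))),
    (FakeTensor((7, 1, 28, 28)), FakeTensor((7,))),
    (FakeTensor((3, 1, 28, 28)), FakeTensor((3,))),
]
counts = count_batch_shapes(batches)
for shape, count in counts.most_common():
    print(shape, count)
print("Number of unique shapes:", len(counts))
```

The fewer distinct keys such a counter holds, the fewer input-driven recompilations can be expected.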
Image datasets
Low input dynamicity
In the example below, the MNIST dataset with a batch size of 7 has 2 input shapes, one for batch size = 7 and the other for batch size = 3 (because MNIST has 60000 training images, and 60000%7=3, so the last batch has batch size = 3). This is considered a very low and acceptable amount of dynamicity in the input.
from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
# Creating a sample MNIST dataloader
mnist_ds = torchvision.datasets.MNIST('mnist', download=True, transform=torchvision.transforms.ToTensor())
mnist_dl = DataLoader(mnist_ds, batch_size=7, num_workers=2)
# Call the dataloader dynamicity tool on the dataloader
res = data_dynamicity(mnist_dl)
Running this code, the tool produces the following output, which shows only two distinct input shapes and therefore very little dynamicity.
==============================================================================
|Shape |Count |
==============================================================================
|((7, 1, 28, 28), (7,)) |8571 |
------------------------------------------------------------------------------
|((3, 1, 28, 28), (3,)) |1 |
------------------------------------------------------------------------------
Number of unique shapes: 2
There is a little dynamicity in input data shapes
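The two rows in this report follow directly from the arithmetic of splitting 60000 samples into batches of 7. A small sketch of that calculation (the helper name `batch_size_counts` is hypothetical):

```python
def batch_size_counts(num_samples, batch_size, drop_last=False):
    """Map each batch size that occurs to the number of batches with that size."""
    full_batches, remainder = divmod(num_samples, batch_size)
    counts = {}
    if full_batches:
        counts[batch_size] = full_batches
    # The leftover samples form one smaller final batch unless it is dropped
    if remainder and not drop_last:
        counts[remainder] = 1
    return counts


print(batch_size_counts(60000, 7))                  # {7: 8571, 3: 1}
print(batch_size_counts(60000, 7, drop_last=True))  # {7: 8571}
```

The second call mirrors the `drop_last=True` option of `torch.utils.data.DataLoader`, which discards the incomplete final batch and so removes the second input shape, at the cost of skipping a few samples per epoch.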
High input dynamicity
On the other hand, for the Flowers 102 dataset, we have images of different shapes. We see 29 different input shapes in the next example.
pip install scipy
from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
import torch
# Join a list of images/labels into a single batched tensor.
# In this case we find the smallest height and width in the batch,
# and then crop every image to that size so the images can be stacked
def collate(batch):
    dim1 = min([k[0].shape[1] for k in batch])
    dim2 = min([k[0].shape[2] for k in batch])
    images = torch.stack([k[0][:,:dim1,:dim2] for k in batch])
    labels = torch.tensor([k[1] for k in batch])
    return (images,labels)
flowers_ds = torchvision.datasets.Flowers102('flowers', download=True, transform=torchvision.transforms.ToTensor())
flowers_dl = DataLoader(flowers_ds, batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(flowers_dl)
===================================================================================
|Shape |Count |
===================================================================================
|((7, 3, 500, 500), (7,)) |111 |
-----------------------------------------------------------------------------------
|((7, 3, 500, 667), (7,)) |5 |
----------------------------------------------------------------------------------
|((7, 3, 500, 528), (7,)) |2 |
-----------------------------------------------------------------------------------
|((7, 3, 500, 501), (7,)) |2 |
----------------------------------------------------------------------------------
|((7, 3, 500, 542), (7,)) |2 |
--------------------------------------------------------------------------------
|((7, 3, 500, 592), (7,)) |1 |
----------------------------------------------------------------------------------
***
***
|((7, 3, 500, 549), (7,)) |1 |
-----------------------------------------------------------------------------------
|((5, 3, 500, 500), (5,)) |1 |
-----------------------------------------------------------------------------------
Number of unique shapes: 29
There is a lot of dynamicity in input data shapes
Depending on the use case, we can bucket images to certain fixed sizes, or resize/crop them to a single shape. A center-crop solution is shown in the example below, which makes the Flowers 102 dataset more static.
from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
import torch
def collate(batch):
    images = torch.stack([k[0] for k in batch])
    labels = torch.tensor([k[1] for k in batch])
    return (images,labels)
# Center crop to a fixed size, applied as a transform
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.CenterCrop((300,300))])
flowers_ds = torchvision.datasets.Flowers102('flowers', download=True, transform=transform)
flowers_dl = DataLoader(flowers_ds, batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(flowers_dl)
==============================================================================
|Shape |Count |
==============================================================================
|((7, 3, 300, 300), (7,)) |145 |
------------------------------------------------------------------------------
|((5, 3, 300, 300), (5,)) |1 |
------------------------------------------------------------------------------
Number of unique shapes: 2
There is a little dynamicity in input data shapes
Text datasets
We often have high input dynamicity for text datasets, because sentence lengths vary a lot. In the example below, we have 443 different shapes for the SQUAD dataset when batching with batchsize=7. Within each batch we pad to the longest sentence length.
pip install datasets
pip install transformers
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity
# Pad to max length sentence in each batch
def collate(batch):
    def pad(item, val, maxlen):
        return torch.tensor([i + [val]*(maxlen-len(i)) for i in item])
    token = [k['token_type_ids'] for k in batch]
    attention = [k['attention_mask'] for k in batch]
    inp = [k['input_ids'] for k in batch]
    token_lens = [len(i) for i in token]
    # Find the max length sentence in this batch
    max_len = max(token_lens)
    assert token_lens == [len(i) for i in attention] == [len(i) for i in inp]
    return {'token_type_ids': pad(token, 0, max_len), 'attention_mask': pad(attention, 0, max_len), 'input_ids': pad(inp, 0, max_len)}
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)
dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(dt)
=================================================================================================================
|Shape |Count |
=================================================================================================================
|((-1023680607561683160, (7, 160)), (-4748259973688274144, (7, 160)), (-5213422677791015773, (7, 160))) |114 |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 145)), (-4748259973688274144, (7, 145)), (-5213422677791015773, (7, 145))) |109 |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 143)), (-4748259973688274144, (7, 143)), (-5213422677791015773, (7, 143))) |108 |
------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 180)), (-4748259973688274144, (7, 180)), (-5213422677791015773, (7, 180))) |107 |
----------------------------------------------------------------------------------------------------------------
***
***
|((-1023680607561683160, (7, 149)), (-4748259973688274144, (7, 149)), (-5213422677791015773, (7, 149))) |99 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 513)), (-4748259973688274144, (7, 513)), (-5213422677791015773, (7, 513))) |1 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 431)), (-4748259973688274144, (7, 431)), (-5213422677791015773, (7, 431))) |1 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 159)), (-4748259973688274144, (1, 159)), (-5213422677791015773, (1, 159))) |1 |
----------------------------------------------------------------------------------------------------------------
Number of unique shapes: 443
There is a lot of dynamicity in input data shapes
A very simple way to get static shapes is to pad the data to the longest sentence length. However this is inefficient computationally, because we are wasting compute effort on the padded sections which are thrown away later.
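To get a feel for how much compute such padding wastes, one can estimate the fraction of token positions that hold padding. This is a rough sketch: the helper name `padding_overhead` and the five sentence lengths are hypothetical, while 867 is the longest tokenized length reported for this dataset in this tutorial.

```python
def padding_overhead(lengths, padded_length):
    """Fraction of token positions that are padding when every
    sequence is padded to padded_length."""
    if any(n > padded_length for n in lengths):
        raise ValueError("padded_length is shorter than a sequence")
    total_slots = padded_length * len(lengths)
    real_tokens = sum(lengths)
    return (total_slots - real_tokens) / total_slots


lengths = [160, 145, 143, 180, 149]  # hypothetical sentence lengths in one batch

# Pad to the longest sentence in the batch
print(f"{padding_overhead(lengths, max(lengths)):.1%}")  # 13.7%
# Pad to the longest sentence in the whole dataset
print(f"{padding_overhead(lengths, 867):.1%}")           # 82.1%
```

For these illustrative lengths, padding to the dataset maximum means most of the positions processed are padding, which is the inefficiency described above.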
In the next example we show the same SQUAD dataset padded to maximum sentence length, and thus exhibiting low input dynamicity.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity
# Pad to max sentence length in the whole dataset
def get_collate(max_sentence):
    def collate(batch):
        def pad(item, val):
            return torch.tensor([i + [val]*(max_sentence-len(i)) for i in item])
        token = [k['token_type_ids'] for k in batch]
        attention = [k['attention_mask'] for k in batch]
        inp = [k['input_ids'] for k in batch]
        return {'token_type_ids': pad(token, 0), 'attention_mask': pad(attention, 0), 'input_ids': pad(inp, 0)}
    return collate
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)
# Find max sentence length in the whole dataset
max_sentence = max([len(dt['input_ids']) for dt in tokenized_dataset['train']])
dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=get_collate(max_sentence))
res = data_dynamicity(dt)
==================================================================================================================
|Shape |Count |
==================================================================================================================
|((-1023680607561683160, (7, 867)), (-4748259973688274144, (7, 867)), (-5213422677791015773, (7, 867))) |12514 |
------------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 867)), (-4748259973688274144, (1, 867)), (-5213422677791015773, (1, 867))) |1 |
------------------------------------------------------------------------------------------------------------------
Number of unique shapes: 2
There is a little dynamicity in input data shapes
Bucketing lets us reduce recompilations without wasting the computation that padding everything to the longest sentence would cost. Here we select a hyperparameter, which is the number of buckets, and use some algorithm to divide the range between the lengths of the shortest and the longest sentences in the dataset into buckets. Then for each batch we find the longest sentence in the batch and pad the batch to the bucket boundary just larger than it.
For a case study using wav2vec please refer to this example
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity
import numpy as np
# Choose bucket boundaries at evenly spaced percentiles of the sentence lengths
def get_buckets(sizes, num_buckets):
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation="lower",  # NumPy 1.22 and later name this argument `method`
        )[1:]
    )
    return buckets
# Find the largest sentence in the batch
# Then find the bucket just larger than it, and pad everything to that
def get_collate(buckets):
    def collate(batch):
        def pad(item, val):
            max_in_batch = max([len(i) for i in item])
            nearest_bucket = np.where(buckets>=max_in_batch)[0][0]
            return torch.tensor([i + [val]*(buckets[nearest_bucket]-len(i)) for i in item])
        token = [k['token_type_ids'] for k in batch]
        attention = [k['attention_mask'] for k in batch]
        inp = [k['input_ids'] for k in batch]
        return {'token_type_ids': pad(token, 0), 'attention_mask': pad(attention, 0), 'input_ids': pad(inp, 0)}
    return collate
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)
buckets = get_buckets([len(dt['input_ids']) for dt in tokenized_dataset['train']], 5)
dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=get_collate(buckets))
res = data_dynamicity(dt)
===============================================================================================================
|Shape |Count |
================================================================================================================
|((-1023680607561683160, (7, 867)), (-4748259973688274144, (7, 867)), (-5213422677791015773, (7, 867))) |4543 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 207)), (-4748259973688274144, (7, 207)), (-5213422677791015773, (7, 207))) |3350 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 164)), (-4748259973688274144, (7, 164)), (-5213422677791015773, (7, 164))) |2277 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 138)), (-4748259973688274144, (7, 138)), (-5213422677791015773, (7, 138))) |1388 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 114)), (-4748259973688274144, (7, 114)), (-5213422677791015773, (7, 114))) |956 |
----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 164)), (-4748259973688274144, (1, 164)), (-5213422677791015773, (1, 164))) |1 |
----------------------------------------------------------------------------------------------------------------
Number of unique shapes: 6
There is some dynamicity in input data shapes
You can now observe that the model sees only six input shapes: the sentence lengths have been separated into buckets, and each batch is padded with the smallest amount of padding needed to reach its bucket.
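The bucket lookup itself does not depend on NumPy or PyTorch. As a minimal sketch, the same selection can be written with the standard-library `bisect` module; the helper name `pad_to_bucket` and the sample batch are hypothetical, and the bucket boundaries are the five sequence lengths that appear in the report above.

```python
from bisect import bisect_left


def pad_to_bucket(sequences, buckets, pad_value=0):
    """Pad every sequence to the smallest bucket that fits the longest one.

    `buckets` must be sorted in ascending order.
    """
    longest = max(len(seq) for seq in sequences)
    idx = bisect_left(buckets, longest)  # first bucket >= longest
    if idx == len(buckets):
        raise ValueError("sequence is longer than the largest bucket")
    target = buckets[idx]
    return [list(seq) + [pad_value] * (target - len(seq)) for seq in sequences]


buckets = [114, 138, 164, 207, 867]       # boundaries seen in the report above
batch = [[1] * 120, [1] * 150, [1] * 90]  # hypothetical token id lists
padded = pad_to_bucket(batch, buckets)
print({len(seq) for seq in padded})  # {164}
```

Because every batch lands on one of a handful of boundaries, the number of distinct sequence lengths is bounded by the number of buckets, whatever the dataset size.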
Detecting Dynamic Ops
Now that we know how to detect dynamic inputs, in the next section we shall try to detect dynamic ops in models. Dynamic ops are those operations whose output shapes cannot be predicted just from knowing the input shapes.
Simple example
In the next example, we have a simple toy model, which we run for 4 steps. The input shape changes at the 3rd step (Step 2 in the tables below, which count from 0), so we expect recompilation there. However, the model itself also has dynamic ops, so we will see the tool identify the module which might be dynamic.
The code examples below can be run in a terminal window. Simply copy this code into a Python file (dyn_ops.py) and run it from the terminal.
from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
import torch
class InnerNet(torch.nn.Module):
    def __init__(self):
        super(InnerNet, self).__init__()
        self.conv = torch.nn.Conv2d(1, 8, 3, 3)
    def forward(self, x):
        x = torch.flatten(self.conv(x), 1)
        x = x[x>0] # This is dynamic
        return x.sum()
net = torch.nn.Sequential(torch.nn.ReLU(), InnerNet()).to('hpu')
net = detect_recompilation_auto_model(net) # wrap model in dynamic op detection tool
for bs in [20,20,30,30]: #Input shape changes at 3rd step
    inp = torch.rand(bs, 1, 50, 50).to('hpu')
    print(net(inp))
net.analyse_dynamicity() # Call this after a few steps to generate the dynamicity report
The tool outputs 2 tables (and corresponding csv files).
The first one shows what happens at each step, while the second one shows which module/submodule recompiled the most times. Let’s analyse the first table step by step.
Step | Recompiling modules | New in | New out | Class | Location | Comment |
0 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
0 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
0 | Net/1 | True | True | __main__.InnerNet | dyn_ops.py | Recompiled due to new input shape |
0 | Net | True | True | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
1 | Net/1 | False | False | __main__.InnerNet | dyn_ops.py | Already processed input shape still recompiled. Maybe dyn ops |
1 | Net | False | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Already processed input shape still recompiled. Maybe dyn ops. Could be due to dynamic child |
2 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
2 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
2 | Net/1 | True | False | __main__.InnerNet | dyn_ops.py | Recompiled due to new input shape |
2 | Net | True | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
3 | Net/1 | False | False | __main__.InnerNet | dyn_ops.py | Already processed input shape still recompiled. Maybe dyn ops |
3 | Net | False | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Already processed input shape still recompiled. Maybe dyn ops. Could be due to dynamic child |
Step 0: The first four lines of the first table show that all modules recompile, since it is the first step.
Step 1: The next two lines show that InnerNet and Net recompile. The “Comment” column, however, shows that InnerNet might be dynamic because it recompiled even without dynamic child modules, while Net might not be dynamic, as it may have recompiled only because its child (InnerNet) recompiled.
Step 2: The next four lines show Step 2, where a new input shape is seen, so every module recompiles, as expected and as noted in the “Comment” column.
Step 3: The last two lines, for Step 3, again point to InnerNet as having dynamic ops.
Thus possible outputs from the tool’s “Comment” column are:
- Recompiled due to new input shape
- Already processed input shape still recompiled and has new output shape. Maybe dyn ops. Could be due to dynamic child
- Already processed input shape still recompiled. Maybe dyn ops
- Already processed input shape still recompiled and has new output shape. Maybe dyn ops
The first comment is due to a new input shape, so recompilation is expected. The second comment indicates that a module may have recompiled only because one of its child modules recompiled. The last 2 comments are of interest because they identify modules which contain the dynamic op.
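When the report is long, the rows of interest can be picked out mechanically. The sketch below parses the pipe-separated table as it is printed in this tutorial; the helper name `likely_dynamic_modules` is hypothetical, the sample rows are abbreviated from the first table, and the layout of the csv files the tool writes is not assumed here.

```python
def likely_dynamic_modules(report_text):
    """Return module paths whose comment suggests a dynamic op in the module
    itself rather than in one of its children."""
    modules = []
    for line in report_text.strip().splitlines():
        cells = [cell.strip() for cell in line.strip().strip("|").split("|")]
        # Expect: Step, Recompiling modules, New in, New out, Class, Location, Comment
        if len(cells) < 7 or cells[0] == "Step":
            continue
        module, comment = cells[1], cells[6]
        if "Maybe dyn ops" in comment and "dynamic child" not in comment:
            if module not in modules:
                modules.append(module)
    return modules


# Rows abbreviated from the first table (Location paths shortened)
report = """
Step | Recompiling modules | New in | New out | Class | Location | Comment |
0 | Net/1 | True | True | __main__.InnerNet | dyn_ops.py | Recompiled due to new input shape |
1 | Net/1 | False | False | __main__.InnerNet | dyn_ops.py | Already processed input shape still recompiled. Maybe dyn ops |
1 | Net | False | False | torch.nn.modules.container.Sequential | container.py | Already processed input shape still recompiled. Maybe dyn ops. Could be due to dynamic child |
"""
print(likely_dynamic_modules(report))  # ['Net/1']
```

Rows that mention a dynamic child are skipped on purpose, since the parent module is usually not the source of the dynamicity.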
Note that this tool takes time to run, so it is recommended to run it for a small number of steps. It should also be run on a single card (without distributed execution). Finally, while the tool can detect recompilation due to inputs (and ignore those), it is recommended to pass in same-shape inputs where possible to save time running the tool. With static inputs, the tool can focus only on finding dynamic ops, which is a more interesting case than just dynamic inputs.
In the next example, we replace the dynamic portion with a static equivalent. On running the detect_recompilation_auto_model tool, we now see dynamicity only from inputs
from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
import torch
class InnerNet(torch.nn.Module):
    def __init__(self):
        super(InnerNet, self).__init__()
        self.conv = torch.nn.Conv2d(1, 8, 3, 3)
    def forward(self, x):
        x = torch.flatten(self.conv(x), 1)
        #x = x[x>0] # This is dynamic, replacing in next line with static implementation
        x = torch.where(x>0, x, torch.zeros_like(x))
        return x.sum()
net = torch.nn.Sequential(torch.nn.ReLU(), InnerNet()).to('hpu')
net = detect_recompilation_auto_model(net)
for bs in [20,20,30,30]: #Input shape changes at 3rd step
    inp = torch.rand(bs, 1, 50, 50).to('hpu')
    print(net(inp))
net.analyse_dynamicity() # Call this after a few steps to generate the dynamicity report
Step | Recompiling modules | New in | New out | Class | Location | Comment |
0 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
0 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
0 | Net/1 | True | True | __main__.InnerNet | dyn_ops_static.py | Recompiled due to new input shape |
0 | Net | True | True | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
2 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
2 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
2 | Net/1 | True | False | __main__.InnerNet | dyn_ops_static.py | Recompiled due to new input shape |
2 | Net | True | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
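Why the rewrite is safe can be seen with a plain-Python analogue that uses lists in place of tensors. This is only an illustration of the idea, with hypothetical helper names: selecting the positive values gives an intermediate whose length depends on the data, while replacing the non-positive values with zeros keeps the input length, and both give the same sum.

```python
def masked_sum(values):
    """Analogue of x[x > 0].sum(): the intermediate length depends on the data."""
    selected = [v for v in values if v > 0]
    return len(selected), sum(selected)


def where_sum(values):
    """Analogue of torch.where(x > 0, x, zeros).sum(): the intermediate
    always has the same length as the input."""
    replaced = [v if v > 0 else 0.0 for v in values]
    return len(replaced), sum(replaced)


# Two inputs of the same length but with different contents
for values in ([0.5, -1.0, 2.0, -3.0], [0.5, 1.0, 2.0, -3.0]):
    print(masked_sum(values), where_sum(values))
# (2, 2.5) (4, 2.5)
# (3, 3.5) (4, 3.5)
```

The first element of each pair plays the role of the intermediate tensor shape: it changes from input to input in the masked version and stays fixed in the where version, which is why only the former triggers recompilation for an already seen input shape.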
Real model example
In the next example we will look at a real model, Faster RCNN, and try to detect dynamic sections in the model
wget https://ultralytics.com/assets/coco128.zip
unzip coco128.zip
import torchvision, os
from PIL import Image
import torchvision.transforms as T
import habana_frameworks.torch.core as htcore
device = 'hpu'
#load model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval() # set to evaluation mode
model = model.to(device) # move model to device
from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
model = detect_recompilation_auto_model(model, waittime=0.3)
for idx, k in enumerate(os.listdir('coco128/images/train2017/')):
    img = Image.open('coco128/images/train2017/' + k).resize((600,600))
    img = T.ToTensor()(img).to(device)
    print('inp shape:', img.shape)
    pred = model([img])
    htcore.mark_step()
    if idx == 6: # just running first few images
        break
    print('done img', idx)
model.analyse_dynamicity()
From the outputs, we see the following:
Step | Recompiling modules | New in | New out | Class | Location | Comment |
1 | Net/roi_heads/box_roi_pool | False | False | torchvision.ops.poolers.MultiScaleRoIAlign | /usr/local/lib/python3.8/dist-packages/torchvision/ops/poolers.py | Already processed input shape still recompiled. Maybe dyn ops |
1 | Net/roi_heads | False | True | torchvision.models.detection.roi_heads.RoIHeads | /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py | Already processed input shape still recompiled and has new output shape. Maybe dyn ops |
This tells us that the MultiScaleRoIAlign and RoIHeads classes have some dynamic ops in them. Checking the modules, we find a where op used in MultiScaleRoIAlign and another where op used in RoIHeads. We can try to rewrite these sections as static, as discussed here, or move the operation to CPU. For such strategies please see this reference.
Summary
This performance tutorial showed how to optimize your model by detecting excessive recompilations due to dynamic inputs and Ops, which slow down execution. The tools covered here detect recompilations, and the examples show how to reduce them, either by using Habana’s automated dynamic shapes handler or by explicitly creating bucketing algorithms.
Copyright© 2023 Habana Labs, Ltd. an Intel Company.
Licensed under the Apache License, Version 2.0 (the “License”);
You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.