
PyTorch Mixed Precision


Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory. By keeping certain parts of the model in 32-bit types for numerical stability, the model will have a lower step time and train equally well in terms of evaluation metrics such as accuracy.

In float32, each value takes 32 bits of memory. The two lower-precision dtypes, float16 and bfloat16, each take 16 bits instead. Modern accelerators can run operations faster in the 16-bit dtypes, because they have specialized hardware for 16-bit computations and because 16-bit values can be read from memory faster.
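
To make the two 16-bit formats concrete, here is a minimal pure-Python sketch (no Habana or PyTorch needed). It emulates rounding a value to bfloat16 by keeping the upper 16 bits of its float32 encoding, and rounds to float16 with the standard struct 'e' format. The helper names to_bfloat16 and to_float16 are illustrative; frameworks do this conversion in hardware or native code.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 is the upper half of a float32: 1 sign bit, the same 8 exponent
    bits, and 7 of the 23 mantissa bits. NaN and infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

print(struct.calcsize("<f"), struct.calcsize("<e"))  # 4 2 (bytes per value)
print(to_bfloat16(1 / 3), to_float16(1 / 3))         # 0.333984375 0.333251953125
print(to_bfloat16(1e5))                               # 99840.0
# to_float16(1e5) raises OverflowError: float16's largest finite value is 65504
```

The output shows the trade-off between the two 16-bit types: float16 keeps more mantissa bits (it lands closer to 1/3), while bfloat16 keeps the float32 exponent range, so 1e5 stays representable where float16 overflows.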

Habana HPUs run operations in bfloat16 faster than in float32, so bfloat16 should be used whenever possible on HPUs. However, variables and a few computations should still be kept in float32 for numerical stability, so that the model trains to the same quality. PyTorch mixed precision lets you use a mix of bfloat16 and float32 during model training, getting the performance benefits of bfloat16 and the numerical stability of float32.

Note: In this tutorial, “numerical stability” refers to how a model’s quality is affected by using a lower-precision dtype instead of a higher-precision one. An operation is “numerically unstable” in bfloat16 if running it in bfloat16 gives the model worse evaluation accuracy or other metrics than running it in float32.
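
A toy illustration of numerical instability: adding many small values to a larger running total. It makes the simplifying assumption that every intermediate sum is rounded to bfloat16 (real kernels often accumulate in higher precision internally), and the helpers to_bfloat16 and running_sum are hypothetical.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def running_sum(values, round_fn):
    """Sum values left to right, rounding the total after every addition."""
    total = round_fn(0.0)
    for v in values:
        total = round_fn(total + round_fn(v))
    return total

values = [1.0] + [0.001] * 1000
bf16_total = running_sum(values, to_bfloat16)   # stays at 1.0
exact_total = running_sum(values, lambda v: v)  # approximately 2.0 in double precision
print(bf16_total, exact_total)
```

Near 1.0, adjacent bfloat16 values are 2^-7 = 0.0078125 apart, so adding about 0.001 rounds straight back to 1.0 and the whole contribution of the small terms is lost. Losses like this in long reductions are one reason OPs such as softmax, log_softmax and nll_loss appear in the default FP32 list.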

Supported hardware

Habana Gaudi HPUs support a mix of bfloat16 and float32.

On CPUs, mixed precision gives no speedup and will run significantly slower. The mixed precision APIs can still be used there for unit testing, debugging, or just trying out the API.


To add PyTorch mixed precision support to a model script, add the following lines anywhere in the script before the start of the training loop:

import torch
from habana_frameworks.torch.hpex import hmp
hmp.convert()

Keep any segment of the script in which you want to avoid mixed precision (e.g. the optimizer step) under the following Python context manager:

from habana_frameworks.torch.hpex import hmp

with hmp.disable_casts():
    # statements in this block run without mixed precision casts, e.g.:
    optimizer.step()
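
Conceptually, a context like this can be modeled as a switch that the cast wrappers consult before casting. The toy sketch below uses a module-level flag and contextlib.contextmanager; the names disable_casts and cast_if_enabled here are hypothetical stand-ins, not HMP's implementation, and dtypes are plain strings.

```python
from contextlib import contextmanager

# Hypothetical module-level switch consulted by every cast wrapper.
_casts_enabled = True

@contextmanager
def disable_casts():
    """Turn automatic casts off for the duration of a with-block."""
    global _casts_enabled
    previous = _casts_enabled
    _casts_enabled = False
    try:
        yield
    finally:
        # Restore the previous state even if the block raises.
        _casts_enabled = previous

def cast_if_enabled(dtype: str, target: str) -> str:
    """Return the dtype an input would have after the cast step."""
    return target if _casts_enabled else dtype

print(cast_if_enabled("fp32", "bf16"))      # bf16
with disable_casts():
    print(cast_if_enabled("fp32", "bf16"))  # fp32
print(cast_if_enabled("fp32", "bf16"))      # bf16
```

Saving and restoring the previous value in a finally clause means nested blocks, and blocks that raise, never leave casts switched off by accident.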

To enable the Gaudi device, also add the following code, which loads the Habana PyTorch modules:

from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()

It prints a message such as:

Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib

Design Rules

HMP maintains two lists:

  • OPs that always run in BF16
  • OPs that always run in FP32

HMP uses Python decorators to add the required functionality (BF16 or FP32 casts on the OP inputs) to torch functions (refer to the code snippet below).

Any OP that is in neither list runs with the precision type of its first input, except in the cases below.

For OPs with multiple tensor inputs (maintained in a separate list, e.g. add, sub, cat, stack), all inputs are cast to the widest precision type among the inputs, so mixing BF16 and FP32 inputs gives FP32. If any of these OPs is also in the BF16 or FP32 list, that list takes precedence.

For in-place OPs (where the output and the first input share storage), all inputs are cast to the precision type of the first input.
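
The rules above can be sketched as a small dtype-resolution function. This is a hypothetical model of the rules as stated here, not HMP's source: the OP sets are abbreviated, dtypes are plain strings, and checking the in-place rule before the multi-input rule is this sketch's reading of the text.

```python
from typing import Sequence

# Abbreviated, hypothetical OP sets for illustration only.
BF16_OPS = {"addmm", "bmm", "conv2d", "mm"}
FP32_OPS = {"log_softmax", "nll_loss", "softmax"}
MULTI_INPUT_OPS = {"add", "sub", "cat", "stack"}

# FP32 is "wider" than BF16.
WIDTH = {"bf16": 16, "fp32": 32}

def resolve_dtype(op: str, input_dtypes: Sequence[str], inplace: bool = False) -> str:
    """Return the dtype all tensor inputs of `op` are cast to."""
    if op in BF16_OPS:          # explicit lists take precedence
        return "bf16"
    if op in FP32_OPS:
        return "fp32"
    if inplace:                 # output shares storage with the first input
        return input_dtypes[0]
    if op in MULTI_INPUT_OPS:   # promote to the widest input type
        return max(input_dtypes, key=WIDTH.__getitem__)
    return input_dtypes[0]      # default: follow the first input

print(resolve_dtype("mm", ["fp32", "fp32"]))                  # bf16
print(resolve_dtype("add", ["bf16", "fp32"]))                 # fp32
print(resolve_dtype("add_", ["bf16", "fp32"], inplace=True))  # bf16
print(resolve_dtype("relu", ["bf16"]))                        # bf16
```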

from functools import wraps

def op_wrap(op, cast_fn):
    """Adds a wrapper function to an OP. All tensor inputs
    of the OP are cast to the type determined by cast_fn.

    op (torch.nn.functional/torch/torch.Tensor): Input OP
    cast_fn (to_bf16/to_fp32): Fn to cast input tensors

    Returns the wrapper function that is inserted back into
    the corresponding module for this OP.
    """
    @wraps(op)  # keep the OP's name and docstring on the wrapper
    def wrapper(*args, **kwds):
        # get_new_args (helper not shown here) applies cast_fn
        # to the tensor inputs of the OP
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)
    return wrapper
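
The same wrapping pattern can be run end to end with toy stand-ins. In this self-contained sketch, ToyTensor, to_bf16, to_fp32, get_new_args and the fake_torch namespace are all hypothetical: the toy tensor only records a dtype, and the toy get_new_args casts positional arguments only.

```python
from dataclasses import dataclass
from functools import wraps
from types import SimpleNamespace

@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for torch.Tensor that only records a dtype."""
    dtype: str

def to_bf16(x):
    """Toy cast: tensors become BF16, anything else passes through."""
    return ToyTensor("bf16") if isinstance(x, ToyTensor) else x

def to_fp32(x):
    """Toy cast: tensors become FP32, anything else passes through."""
    return ToyTensor("fp32") if isinstance(x, ToyTensor) else x

def get_new_args(cast_fn, args, kwds):
    """Hypothetical helper: cast every positional argument.
    Keyword arguments are left unchanged in this sketch."""
    return tuple(cast_fn(a) for a in args)

def op_wrap(op, cast_fn):
    """Return a wrapper that casts op's tensor inputs with cast_fn."""
    @wraps(op)  # keep op's __name__ and docstring on the wrapper
    def wrapper(*args, **kwds):
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)
    return wrapper

def mm(a, b):
    """Toy OP that reports the dtypes it actually received."""
    return (a.dtype, b.dtype)

def softmax(x, dim=-1):
    """Toy OP with a non-tensor keyword argument."""
    return (x.dtype, dim)

# A namespace standing in for a torch module; wrappers are inserted back into it.
fake_torch = SimpleNamespace(mm=mm, softmax=softmax)
fake_torch.mm = op_wrap(fake_torch.mm, to_bf16)
fake_torch.softmax = op_wrap(fake_torch.softmax, to_fp32)

print(fake_torch.mm(ToyTensor("fp32"), ToyTensor("fp32")))  # ('bf16', 'bf16')
print(fake_torch.softmax(ToyTensor("bf16"), dim=1))         # ('fp32', 1)
print(fake_torch.mm.__name__)                               # mm
```

Because the wrapper replaces the attribute on the module, every later call site picks up the casts without any change to the model code.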

Configuration Options

Habana Mixed Precision (HMP) provides two modes of mixed precision training, O1 and O2. Select the mode by passing opt_level to hmp.convert().

O1 is the default and recommended mode when using HMP. O2 can be used to debug convergence issues and during the initial iterations of converting a new model to run with mixed precision.

Opt_level = O1

In this mode, the OPs that always run in BF16 and those that always run in FP32 are taken from a BF16 list and an FP32 list, respectively. The BF16 list contains OPs that are numerically safe to run in lower precision on HPU, whereas the FP32 list contains OPs that should run in higher precision (a conservative choice that works across models).

Default BF16 list = [addmm, bmm, conv1d, conv2d, conv3d, dot, mm, mv]

Default FP32 list = [batch_norm, cross_entropy, log_softmax, softmax, nll_loss, topk]

You can override these internal lists with your own BF16 and FP32 lists by passing bf16_file_path=<.txt> and fp32_file_path=<.txt> as arguments to hmp.convert(). This is particularly useful when customizing mixed precision training for a particular model. For example:

Custom BF16 list for ResNet50 = [addmm, avg_pool2d, bmm, conv2d, dot, max_pool2d, mm, mv, relu, t, linear]

Custom FP32 list for ResNet50 = [cross_entropy, log_softmax, softmax, nll_loss, topk]
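
To pass custom lists such as the ResNet50 ones above to hmp.convert(), they have to be saved as text files. A minimal sketch, assuming one OP name per line (this tutorial does not specify the file format, so check the HMP documentation for your release); write_op_list and the file names are hypothetical, and the hmp.convert() call is left as a comment because it needs an HPU machine.

```python
import os
import tempfile

# Custom ResNet50 lists from the text above.
RESNET50_BF16 = ["addmm", "avg_pool2d", "bmm", "conv2d", "dot",
                 "max_pool2d", "mm", "mv", "relu", "t", "linear"]
RESNET50_FP32 = ["cross_entropy", "log_softmax", "softmax", "nll_loss", "topk"]

def write_op_list(path, ops):
    """Write one OP name per line (assumed format) and return the absolute path."""
    with open(path, "w") as f:
        f.write("\n".join(ops) + "\n")
    return os.path.abspath(path)

out_dir = tempfile.mkdtemp()
bf16_path = write_op_list(os.path.join(out_dir, "resnet50_bf16.txt"), RESNET50_BF16)
fp32_path = write_op_list(os.path.join(out_dir, "resnet50_fp32.txt"), RESNET50_FP32)
print(bf16_path)
print(fp32_path)

# On an HPU machine, the absolute paths would then be passed as:
# hmp.convert(opt_level="O1", bf16_file_path=bf16_path, fp32_file_path=fp32_path)
```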

Opt_level = O2

In this mode, only GEMM and convolution OPs (e.g. conv1d, conv2d, conv3d, addmm, mm, bmm, mv, dot) run in BF16; all other OPs run in FP32.
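
The difference between the two modes can be summarized as a lookup. A minimal sketch, assuming OP names are plain strings; forced_dtype is a hypothetical helper, not part of HMP. It returns None when, under O1, an OP is in neither list and therefore falls back to the input-based Design Rules.

```python
from typing import Collection, Optional

# Default O1 lists as given in this tutorial.
DEFAULT_BF16 = frozenset({"addmm", "bmm", "conv1d", "conv2d", "conv3d", "dot", "mm", "mv"})
DEFAULT_FP32 = frozenset({"batch_norm", "cross_entropy", "log_softmax",
                          "softmax", "nll_loss", "topk"})
# GEMM and convolution OPs named in the O2 description.
GEMM_CONV = frozenset({"conv1d", "conv2d", "conv3d", "addmm", "mm", "bmm", "mv", "dot"})

def forced_dtype(op: str, opt_level: str = "O1",
                 bf16_ops: Collection[str] = DEFAULT_BF16,
                 fp32_ops: Collection[str] = DEFAULT_FP32) -> Optional[str]:
    """Return the dtype `op` is forced to run in, or None if it follows its inputs."""
    if opt_level == "O2":
        # O2: GEMM/convolution in BF16, everything else in FP32.
        return "bf16" if op in GEMM_CONV else "fp32"
    if opt_level != "O1":
        raise ValueError(f"unknown opt_level: {opt_level!r}")
    # O1: consult the default or custom lists.
    if op in bf16_ops:
        return "bf16"
    if op in fp32_ops:
        return "fp32"
    return None

print(forced_dtype("relu"))        # None (in neither default list)
print(forced_dtype("relu", "O2"))  # fp32
print(forced_dtype("mm", "O2"))    # bf16
```

With the default lists, the O1 BF16 list and the O2 GEMM/convolution set contain the same eight OPs; the modes differ in what happens to every other OP. O1 lets an unlisted OP follow its inputs, while O2 forces it to FP32.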

Usage Example

import torch
from habana_frameworks.torch.hpex import hmp
from habana_frameworks.torch.utils.library_loader import load_habana_module

# load the Habana PyTorch modules so that the "hpu" device is available
load_habana_module()

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device=torch.device("hpu"))
y = torch.randn(N, D_out, device=torch.device("hpu"))

# enable mixed precision training with optimization level O1, default BF16 list, default FP32 list and logging disabled
# use opt_level to select desired mode of operation
# use bf16_file_path to provide absolute path to a file with custom BF16 list
# use fp32_file_path to provide absolute path to a file with custom FP32 list
# use isVerbose to disable/enable debug logs
hmp.convert(opt_level="O1", bf16_file_path="", fp32_file_path="", isVerbose=False)

model = torch.nn.Linear(D_in, D_out).to(torch.device("hpu"))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for t in range(500):
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    # disable mixed precision for optimizer block
    with hmp.disable_casts():
        optimizer.step()

Output:

Loading Habana modules from /usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib
hmp:verbose_mode False
hmp:opt_level O1

HPU performance tips

Consider doubling your batch size when using HPUs: bfloat16 tensors use half the memory of float32 tensors, so a larger batch may fit, and doubling the batch size may increase training throughput.
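
A worked example of the memory arithmetic, using the shapes from the usage example above (N = 64, D_in = 1024, D_out = 512). tensor_bytes is a hypothetical helper. Following the earlier point that variables stay in float32, the weight is counted in float32, so only tensors actually held in bfloat16 shrink and the saving for a whole model is less than 2×.

```python
import math

# Bytes per element for the two dtypes discussed in this tutorial.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}

def tensor_bytes(shape, dtype):
    """Memory taken by a dense tensor of the given shape and dtype."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]

D_in, D_out = 1024, 512

# A float32 input batch of 64 rows takes the same memory as a bfloat16 batch of 128.
fp32_input = tensor_bytes((64, D_in), "float32")
bf16_input = tensor_bytes((128, D_in), "bfloat16")

# The Linear layer's weight (out_features x in_features) is kept in float32,
# so its footprint does not depend on the batch size.
weight = tensor_bytes((D_out, D_in), "float32")

print(fp32_input, bf16_input, weight)  # 262144 262144 2097152
```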


Summary

  • You should use mixed precision if you use Habana HPUs, as it can improve training performance.
  • You can enable mixed precision with the following lines:
      import torch
      from habana_frameworks.torch.hpex import hmp
      hmp.convert()
  • You can customize mixed precision training for a particular model by defining two lists: OPs that always run in BF16 and OPs that always run in FP32.
  • Wrap any segment of code where you want to avoid mixed precision in the following context:
      with hmp.disable_casts():
  • Double the training batch size if it does not reduce evaluation accuracy.

Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
