
Mixed precision

An adaptation of the TensorFlow mixed precision tutorial for Habana Gaudi AI processors.

This tutorial demonstrates enabling mixed-precision training for Keras models.
You can find the full guide to TensorFlow Mixed Precision Training on Gaudi here.


Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory. By keeping certain parts of the model in 32-bit types for numeric stability, the model has a lower step time while training equally well in terms of evaluation metrics such as accuracy. This guide describes how to use the Keras mixed precision API to speed up your models. Using this API can improve performance by up to 2x on Habana HPUs.

Today, most models use the float32 dtype, which takes 32 bits of memory. However, there are two lower-precision dtypes, float16 and bfloat16, each of which takes 16 bits of memory instead. Modern accelerators can run operations faster in the 16-bit dtypes, because they have specialized hardware for 16-bit computations and 16-bit values can be read from memory faster.
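The trade-off of the 16-bit format used on Gaudi can be made concrete without an accelerator. The pure-Python sketch below emulates bfloat16 by rounding a float32 bit pattern to its upper 16 bits (round half to even); `to_bfloat16` is a hypothetical helper written for this illustration, not a TensorFlow API. It shows that bfloat16 keeps float32's 8-bit exponent, so large magnitudes stay finite, but keeps only 8 significant bits, and that each value needs half the bytes.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to float32, then to the nearest bfloat16 (round half to even).

    bfloat16 is the upper 16 bits of a float32: 1 sign bit, 8 exponent bits,
    7 explicit fraction bits. Hypothetical helper for illustration only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add just under half of the dropped range, plus 1 if the kept LSB is odd (ties to even).
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Precision: only about 2-3 significant decimal digits survive.
print(to_bfloat16(1 / 3))                # 0.333984375
# Range: same exponent width as float32, so 1e38 is still finite.
print(math.isfinite(to_bfloat16(1e38)))  # True
# Memory: 2 bytes per bfloat16 value versus 4 bytes per float32 value.
n_values = 1_000_000
print(n_values * 4, n_values * 2)        # 4000000 2000000
```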

Habana HPUs can run operations in bfloat16 faster than in float32. Therefore, bfloat16 should be used whenever possible on those devices. However, variables and a few computations should still be in float32 for numeric reasons so that the model trains to the same quality. The Keras mixed precision API allows you to mix bfloat16 with float32, to get the performance benefits of bfloat16 and the numeric stability benefits of float32.

Note: In this guide, the term “numeric stability” refers to how a model’s quality is affected by the use of a lower-precision dtype instead of a higher-precision dtype. An operation is “numerically unstable” in bfloat16 if running it in that dtype causes the model to have worse evaluation accuracy or other metrics compared to running the operation in float32.
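As a rough illustration of numeric instability, the sketch below (pure Python; `to_bfloat16` is a hypothetical emulation helper, not a TensorFlow API) adds 0.01 one thousand times while rounding the running sum to bfloat16 after every step. Once the sum reaches 4.0, the gap between adjacent bfloat16 values (0.03125) is more than twice the increment, so each addition rounds back to the same value and the sum stalls, whereas a double-precision sum reaches about 10. Long reductions like this are one reason some computations are kept in float32.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate rounding to bfloat16 (round half to even). Hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def accumulate(step: float, count: int, round_fn=lambda v: v) -> float:
    """Add `step` to a running total `count` times, applying round_fn after each add."""
    total = 0.0
    for _ in range(count):
        total = round_fn(total + step)
    return total


exact = accumulate(0.01, 1000)              # Python float (float64) accumulation
bf16 = accumulate(0.01, 1000, to_bfloat16)  # running total rounded to bfloat16
print(round(exact, 6), bf16)                # 10.0 4.0
```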


import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.client import device_lib

Add the following code to enable the Gaudi device:

import habana_frameworks.tensorflow as htf
htf.load_habana_module()

Supported hardware

Habana’s HPUs support a mix of bfloat16 and float32.

Even on CPUs, where no speedup is expected, the mixed precision API can still be used for unit testing, debugging, or just to try it out. However, on CPUs mixed precision will typically run significantly slower.

Setting the dtype policy

To use mixed precision in Keras, you need to create a tf.keras.mixed_precision.Policy, typically referred to as a dtype policy. Dtype policies specify the dtypes layers will run in. In this guide, you will construct a policy from the string 'mixed_bfloat16' and set it as the global policy. This will cause subsequently created layers to use mixed precision with a mix of bfloat16 and float32.

policy = mixed_precision.Policy('mixed_bfloat16')
mixed_precision.set_global_policy(policy)

As a shorthand, you can pass a string directly to set_global_policy, which is what is typically done in practice.

# Equivalent to the two lines above
mixed_precision.set_global_policy('mixed_bfloat16')

The policy specifies two important aspects of a layer: the dtype the layer’s computations are done in, and the dtype of a layer’s variables. Above, you created a mixed_bfloat16 policy (i.e., a mixed_precision.Policy created by passing the string 'mixed_bfloat16' to its constructor). With this policy, layers use bfloat16 computations and float32 variables. Computations are done in bfloat16 for performance, but variables must be kept in float32 for numeric stability. You can directly query these properties of the policy.

print('Compute dtype: %s' % policy.compute_dtype)
print('Variable dtype: %s' % policy.variable_dtype)
Compute dtype: bfloat16
Variable dtype: float32
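Why the variable dtype stays float32 can be seen with a single optimizer-style update. In the sketch below (pure Python; `to_bfloat16` and `to_float32` are hypothetical emulation helpers, and the learning rate and gradient are made-up numbers), a weight of 1.0 receives an update of 0.001. Rounded to bfloat16, whose spacing just above 1.0 is 2**-7 = 0.0078125, the update vanishes; stored in float32, it survives.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Emulate rounding to bfloat16 (round half to even). Hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


weight = 1.0
learning_rate, gradient = 0.01, -0.1  # illustrative values: update = +0.001
update = -learning_rate * gradient

w_f32 = to_float32(weight + update)    # variable stored as float32
w_bf16 = to_bfloat16(weight + update)  # variable stored as bfloat16
print(w_f32 > weight, w_bf16 == weight)  # True True
```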

As mentioned before, the mixed_bfloat16 policy improves performance most significantly on HPUs. The policy also runs on CPUs, but it is not expected to improve performance there.

Building the model

Next, let’s start building a simple model. Very small toy models typically do not benefit from mixed precision, because overhead from the TensorFlow runtime dominates the execution time, making any performance improvement on the HPU negligible. Therefore, let’s build two large Dense layers with 256 units each if an HPU is used.

inputs = keras.Input(shape=(784,), name='digits')
# Check whether a Gaudi (HPU) device is registered with TensorFlow
if any(d.device_type == 'HPU' for d in device_lib.list_local_devices()):
  print('The model will run with 256 units on an HPU')
  num_units = 256
else:
  # Use fewer units on CPUs so the model finishes in a reasonable amount of time
  print('The model will run with 64 units on a CPU')
  num_units = 64
dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(x)
The model will run with 256 units on an HPU

Each layer has a policy and uses the global policy by default. Each of the Dense layers therefore has the mixed_bfloat16 policy, because you set the global policy to mixed_bfloat16 previously. This causes the dense layers to do bfloat16 computations and have float32 variables. They cast their inputs to bfloat16 in order to do bfloat16 computations, so their outputs are bfloat16 as a result. Their variables are float32 and are cast to bfloat16 when the layers are called, to avoid errors from dtype mismatches.

print(dense1.dtype_policy)
print('x.dtype: %s' % x.dtype.name)
# 'kernel' is dense1's variable
print('dense1.kernel.dtype: %s' % dense1.kernel.dtype.name)
<Policy "mixed_bfloat16">
x.dtype: bfloat16
dense1.kernel.dtype: float32
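The casting behavior described above can be mimicked without TensorFlow. The following sketch is a hypothetical toy (`ToyDense` and the bfloat16 emulation helpers are illustrations, not how Keras implements its layers): it stores its weights in float32, casts both the inputs and the weights to bfloat16 when called, and therefore returns a bfloat16-rounded output.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate rounding to bfloat16 (round half to even). Hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class ToyDense:
    """Hypothetical single-output dense unit mimicking a 'mixed_bfloat16' policy."""

    compute_dtype = "bfloat16"
    variable_dtype = "float32"

    def __init__(self, weights):
        # Variables are stored in float32.
        self.kernel = [to_float32(w) for w in weights]

    def __call__(self, inputs):
        # Inputs and the float32 variables are both cast to the compute dtype.
        x = [to_bfloat16(v) for v in inputs]
        k = [to_bfloat16(w) for w in self.kernel]
        # This toy also rounds every product and partial sum to bfloat16;
        # real hardware may accumulate in higher precision.
        acc = 0.0
        for xi, ki in zip(x, k):
            acc = to_bfloat16(acc + to_bfloat16(xi * ki))
        return acc


layer = ToyDense([0.1, 0.2, 0.3])
y = layer([1.0, 1.0, 1.0])
print(y)  # 0.6015625 rather than 0.6
```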

Next, create the output predictions. Normally, you can create the output predictions as follows, but this is not always numerically stable with bfloat16.

# INCORRECT: softmax and model output will be bfloat16, when they should be float32
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
print('Outputs dtype: %s' % outputs.dtype.name)
Outputs dtype: bfloat16

A softmax activation at the end of the model should be float32. Because the dtype policy is mixed_bfloat16, the softmax activation would normally have a bfloat16 compute dtype and output bfloat16 tensors.

This can be fixed by separating the Dense and softmax layers, and by passing dtype='float32' to the softmax layer:

# CORRECT: softmax and model output are float32
x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)
print('Outputs dtype: %s' % outputs.dtype.name)
Outputs dtype: float32

Passing dtype='float32' to the softmax layer constructor overrides the layer’s dtype policy to be the float32 policy, which does computations and keeps variables in float32. Equivalently, you could have instead passed dtype=mixed_precision.Policy('float32'); layers always convert the dtype argument to a policy. Because the Activation layer has no variables, the policy’s variable dtype is ignored, but the policy’s compute dtype of float32 causes softmax and the model output to be float32.

Adding a bfloat16 softmax in the middle of a model is fine, but a softmax at the end of the model should be in float32. The reason is that if the intermediate tensor flowing from the softmax to the loss is bfloat16, numeric issues may occur.
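A small numeric sketch of why the final softmax should stay in float32: below (pure Python; the logits are arbitrary and `to_bfloat16` is a hypothetical emulation helper), probabilities computed in double precision are rounded to bfloat16 before the cross-entropy loss is taken. The rounded probabilities no longer sum to 1, and the loss shifts by more than 0.001, an error that a float32 output largely avoids.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Emulate rounding to bfloat16 (round half to even). Hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def softmax(logits):
    """Numerically stable softmax computed with Python floats (float64)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # arbitrary example logits
target = 0                # index of the true class

probs = softmax(logits)
probs_bf16 = [to_bfloat16(p) for p in probs]

loss = -math.log(probs[target])
loss_bf16 = -math.log(probs_bf16[target])
print(sum(probs_bf16))  # 1.0009765625
print(loss, loss_bf16)
```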

You can override the dtype of any layer to be float32 by passing dtype='float32' if you think it will not be numerically stable with bfloat16 computations. But typically, this is only necessary on the last layer of the model, as most layers have sufficient precision with mixed_bfloat16.

Even if the model does not end in a softmax, the outputs should still be float32. While unnecessary for this specific model, the model outputs can be cast to float32 with the following:

# The linear activation is an identity function. So this simply casts 'outputs'
# to float32. In this particular case, 'outputs' is already float32 so this is a
# no-op.
outputs = layers.Activation('linear', dtype='float32')(outputs)

Next, finish and compile the model, and generate input data:

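model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])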

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

This example casts the input data from uint8 to float32. You don’t cast to bfloat16, since the division by 255 happens on the CPU, which runs bfloat16 operations more slowly than float32 operations. In this case the performance difference is negligible, but in general you should run input-processing math in float32 if it runs on the CPU. The first layer of the model will cast the inputs to bfloat16, as each layer casts floating-point inputs to its compute dtype.

Next, retrieve the model’s initial weights. This allows you to train from scratch again later by reloading them with model.set_weights(initial_weights).

initial_weights = model.get_weights()

Training the model with Model.fit

Next, train the model:

history = model.fit(x_train, y_train,
                    batch_size=8192,
                    epochs=5,
                    validation_split=0.2)
test_scores = model.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])
Train on 48000 samples, validate on 12000 samples
Epoch 1/5
48000/48000 [==============================] - 10s 206us/sample - loss: 1.5784 - accuracy: 0.5529 - val_loss: 0.9286 - val_accuracy: 0.7214
Epoch 2/5
48000/48000 [==============================] - 1s 12us/sample - loss: 0.7776 - accuracy: 0.7743 - val_loss: 0.6914 - val_accuracy: 0.7849
Epoch 3/5
48000/48000 [==============================] - 1s 12us/sample - loss: 0.5905 - accuracy: 0.8260 - val_loss: 0.4472 - val_accuracy: 0.8795
Epoch 4/5
48000/48000 [==============================] - 1s 13us/sample - loss: 0.4701 - accuracy: 0.8607 - val_loss: 0.4482 - val_accuracy: 0.8682
Epoch 5/5
48000/48000 [==============================] - 1s 12us/sample - loss: 0.4683 - accuracy: 0.8545 - val_loss: 0.4122 - val_accuracy: 0.8813
Test loss: 0.42173512543439867
Test accuracy: 0.8749

Notice that the model prints its timing in the logs, here per sample (for example, “12us/sample”). The first epoch may be slower as TensorFlow spends some time optimizing the model, but afterwards the time per step should stabilize.

You can compare the performance of mixed precision with float32. To do so, change the policy from mixed_bfloat16 to float32 in the “Setting the dtype policy” section, then rerun all the cells up to this point. On HPUs, you should see the time per step significantly increase, indicating that mixed precision sped up the model. Make sure to change the policy back to mixed_bfloat16 and rerun the cells before continuing with the guide.

For many real-world models, mixed precision also allows you to double the batch size without running out of memory, as bfloat16 tensors take half the memory. This does not apply to this toy model, however, as you can likely run the model in any dtype where each batch consists of the entire MNIST dataset of 60,000 images.

HPU performance tips

You should try doubling your batch size when using HPUs because bfloat16 tensors use half the memory. Doubling batch size may increase training throughput.
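The batch-size advice follows from simple arithmetic. The sketch below is a rough, hypothetical estimate: it counts only activation memory for the dense model in this tutorial (784 inputs, two 256-unit layers, and 10 logits per example) against an arbitrary 64 MiB budget, and ignores weights, gradients, optimizer state, and framework overhead. Because variables remain float32 under mixed_bfloat16, only the activation share of memory is halved in practice.

```python
BYTES_PER_VALUE = {"float32": 4, "bfloat16": 2}


def max_batch_size(budget_bytes: int, activations_per_example: int, dtype: str) -> int:
    """Largest batch whose activations fit in budget_bytes (activations only)."""
    return budget_bytes // (activations_per_example * BYTES_PER_VALUE[dtype])


# Activations per example for the tutorial's model: input, dense_1, dense_2, logits.
activations = 784 + 256 + 256 + 10
budget = 64 * 1024 * 1024  # hypothetical 64 MiB activation budget

for dtype in ("float32", "bfloat16"):
    print(dtype, max_batch_size(budget, activations, dtype))
# float32 12846
# bfloat16 25692
```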


Summary

  • You should use mixed precision if you use Habana HPUs, as it will improve performance by up to 2x.
  • You can enable mixed precision with the following line: mixed_precision.set_global_policy('mixed_bfloat16')
  • If your model ends in softmax, make sure it is float32. And regardless of what your model ends in, make sure the output is float32.
  • If you use a custom training loop with mixed_bfloat16, no loss scaling is needed: bfloat16 has the same exponent range as float32, so small gradients do not underflow. Wrapping the optimizer in a tf.keras.mixed_precision.LossScaleOptimizer (and calling get_scaled_loss and get_unscaled_gradients) is only required with the mixed_float16 policy.
  • Double the training batch size if doing so does not reduce evaluation accuracy.
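Whether loss scaling is needed depends on the exponent range of the 16-bit format, which can be checked with the standard library: Python's `struct` module packs IEEE float16 values with the 'e' format. In this sketch (the gradient value and the scale factor 1024 are illustrative, and `to_bfloat16` is a hypothetical emulation helper), a gradient of 1e-8 underflows to zero in float16, whose smallest positive value is about 6e-8, unless it is scaled up first, while bfloat16 represents it directly.

```python
import struct


def to_float16(x: float) -> float:
    """Round to IEEE half precision using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Emulate rounding to bfloat16 (round half to even). Hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


grad = 1e-8     # a small gradient value (illustrative)
scale = 1024.0  # a float16-style loss scale (illustrative)

print(to_float16(grad))                  # 0.0: underflows in float16
print(to_float16(grad * scale) / scale)  # nonzero once the value is scaled
print(to_bfloat16(grad))                 # nonzero without any scaling
```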
Copyright (c) 2021 Habana Labs, Ltd. an Intel Company.
Copyright 2019 The TensorFlow Authors.
All rights reserved.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
