
Distributed Training Using TensorFlow and Horovod


This tutorial demonstrates how distributed training works with Horovod using Habana Gaudi AI processors.

Horovod is a distributed deep learning training framework that can achieve high scaling efficiency. With Horovod, users can distribute the training of a model across multiple Gaudi devices and across multiple servers.

To demonstrate distributed training, we will train a simple Keras model on the MNIST database.

You can find more information on distributed training using TensorFlow and Horovod in the Gaudi TensorFlow Scaling tutorial.

Start MPI engines in Jupyter Notebook

Horovod uses MPI to coordinate work between processes. You can find a simple example of how to initialize MPI and run a model with Horovod using the “mpirun” command here.

You can find more information on the Open MPI website.

ipyparallel and mpi4py are required to use MPI from a Jupyter notebook. If they are not installed, install them using the following commands:

# Uncomment the following lines if these packages are not installed
# !pip install jupyter
# !pip install ipyparallel
# !pip install mpi4py

First, import the ipyparallel package, and then start the MPI engines.

In this example, we start 8 MPI engines to use all 8 Gaudi devices in the machine.

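import ipyparallel as ipp
import os

# Allow Open MPI to launch processes as the root user
os.environ["OMPI_ALLOW_RUN_AS_ROOT"] = "1"
os.environ["OMPI_ALLOW_RUN_AS_ROOT_CONFIRM"] = "1"

# Start one MPI engine per Gaudi device (HPU) and connect a client to them
n_hpu = 8
cluster = ipp.Cluster(engines='mpi', n=n_hpu)
client = cluster.start_and_connect_sync()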
Starting 8 engines with <class 'ipyparallel.cluster.launcher.MPIEngineSetLauncher'>
  0%|          | 0/8 [00:00<?, ?engine/s]

Execute Python commands in parallel

The %%px cell magic executes a cell’s Python code on all MPI engines in parallel.

Import TensorFlow

With the MPI engines running, the following cell imports the TensorFlow library on each engine in parallel.

%%px
import tensorflow as tf

Import and enable Habana TensorFlow module

Let’s enable Gaudi devices by loading the Habana module:

%%px
from habana_frameworks.tensorflow import load_habana_module
load_habana_module()
[stderr:4] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:5] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:0] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:7] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:3] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
%px:   0%|          | 0/8 [00:00<?, ?tasks/s]
[stderr:1] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:2] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:6] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!

Import Horovod and get it ready for distributed training

%%px
import horovod.tensorflow.keras as hvd
# Initialize Horovod
hvd.init()
[stdout:6] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:2] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:0] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:5] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:1] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:4] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:7] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so
[stdout:3] Using custom Horovod, path: /usr/local/lib/python3.8/dist-packages/horovod/tensorflow/mpi_lib.cpython-38-x86_64-linux-gnu.so

Download the MNIST database

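Only the first process on each node (local rank 0) downloads the MNIST database; Keras caches it on local disk, so the other processes on the node load the cached copy once the download has finished. The training data is then partitioned evenly among all workers.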
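The partitioning arithmetic in the next cell can be checked on its own. The following plain-Python sketch (the helpers `shard_bounds` and `steps_per_epoch` are hypothetical, not part of Horovod or Keras) computes each rank's contiguous slice the same way the notebook does, using integer division so that any leftover examples are dropped. With MNIST's 60,000 training images and 8 workers, each rank gets 7,500 images, and at a batch size of 128 Keras needs 59 steps per epoch, the last one being a partial batch.

```python
import math


def shard_bounds(num_examples, world_size, rank):
    """Return the [begin, end) slice of the training set owned by `rank`.

    Mirrors the notebook: every rank gets num_examples // world_size
    examples, and the num_examples % world_size leftovers are dropped.
    """
    per_rank = num_examples // world_size
    begin = per_rank * rank
    return begin, begin + per_rank


def steps_per_epoch(shard_size, batch_size):
    """Keras runs a final partial batch, so the step count rounds up."""
    return math.ceil(shard_size / batch_size)


if __name__ == "__main__":
    num_train, world_size, batch_size = 60_000, 8, 128  # MNIST, 8 Gaudi devices
    for rank in range(world_size):
        begin, end = shard_bounds(num_train, world_size, rank)
        print(f"rank {rank}: images [{begin}, {end})")
    print("steps per epoch:", steps_per_epoch(7_500, batch_size))  # 59
```

The 59 steps match the `59/59` progress reported by every engine during training. If the training set were not divisible by the number of workers, the remainder would never be trained on: `shard_bounds(10, 3, 2)` returns `(6, 9)`, so the tenth example is skipped.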

%%px

# Ensure only 1 process downloads the data on each node
if hvd.local_rank() == 0:
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    hvd.broadcast(0, 0)
else:
    hvd.broadcast(0, 0)
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Data partition for different workers
num_pics_per_rank = x_train.shape[0] // hvd.size()
pic_begin = num_pics_per_rank * hvd.rank()
pic_end = pic_begin + num_pics_per_rank
x_train = x_train[pic_begin:pic_end,]
y_train = y_train[pic_begin:pic_end,]

x_train, x_test = x_train / 255.0, x_test / 255.0
[stdout:0] Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
11501568/11490434 [==============================] - 1s 0us/step
[stderr:2] 2022-01-22 08:09:27.737402: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.205166: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:6] 2022-01-22 08:09:27.736178: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.481811: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:4] 2022-01-22 08:09:27.736167: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.491051: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
%px:   0%|          | 0/8 [00:00<?, ?tasks/s]
[stderr:7] 2022-01-22 08:09:27.738345: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.522639: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:5] 2022-01-22 08:09:27.738608: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.544330: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:3] 2022-01-22 08:09:27.736339: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.568677: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:1] 2022-01-22 08:09:27.730682: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:29.575768: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
[stderr:0] 2022-01-22 08:09:28.699532: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:09:30.369017: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
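In the download cell, `hvd.broadcast(0, 0)` is a collective operation that every rank must call, so it effectively acts as a barrier: the ranks that skip the download block there until local rank 0 has finished, and only then call `load_data()`, which finds the cached file. The sketch below reproduces that ordering on a single node, assuming Python threads are an acceptable stand-in for MPI processes and `threading.Barrier` for the Horovod collective; `run_workers` is a hypothetical helper and the cached "data" is placeholder text, not MNIST.

```python
import os
import tempfile
import threading


def run_workers(world_size, cache_path):
    """Simulate one node's ranks sharing a dataset cache file.

    threading.Barrier stands in for hvd.broadcast(0, 0): local rank 0
    "downloads" (writes) the data before reaching the barrier, while the
    other ranks wait at the barrier and read the cached file afterwards.
    Returns the ranks that downloaded and the data each rank loaded.
    """
    barrier = threading.Barrier(world_size)
    downloaded_by = []
    loaded = [None] * world_size

    def worker(local_rank):
        if local_rank == 0:
            with open(cache_path, "w") as f:  # placeholder "download"
                f.write("mnist-bytes")
            downloaded_by.append(local_rank)
            barrier.wait()
        else:
            barrier.wait()  # block until rank 0 has written the cache
        with open(cache_path) as f:  # every rank loads the cached copy
            loaded[local_rank] = f.read()

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return downloaded_by, loaded


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        downloaded_by, loaded = run_workers(8, os.path.join(tmp, "mnist.npz"))
    print(downloaded_by)  # [0]
    print(set(loaded))    # {'mnist-bytes'}
```

Without the barrier, a rank could try to read the cache before rank 0 has finished writing it, or start its own download of the same file.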

Create a model for training

Create a simple model and wrap its optimizer with Horovod’s DistributedOptimizer. DistributedOptimizer averages the gradients across all workers (using an allreduce) and then applies the averaged gradients on each worker.
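To see why averaging gradients across workers can stand in for one large batch, consider a toy linear model `y = w*x + b` with a mean-squared-error loss. The plain-Python sketch below is not Horovod code: `allreduce_mean` is a hypothetical stand-in for the averaging allreduce that DistributedOptimizer performs between processes. It gives each of four workers an equal-sized shard, averages their gradients, and checks that the result matches the gradient computed on the full batch. Because every worker then applies the same averaged gradient, their weights stay identical.

```python
import math


def mse_grad(w, b, xs, ys):
    """Gradient of mean((w*x + b - y)**2) with respect to (w, b)."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    grad_w = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sum(2 * r for r in residuals) / n
    return [grad_w, grad_b]


def allreduce_mean(per_worker_grads):
    """Element-wise mean of each worker's gradient list.

    Stand-in for an averaging allreduce; a real one exchanges tensors
    between processes instead of looping in a single process.
    """
    n = len(per_worker_grads)
    return [sum(values) / n for values in zip(*per_worker_grads)]


def sharded_gradient(w, b, xs, ys, world_size):
    """Split the batch into equal contiguous shards and average their gradients."""
    per_rank = len(xs) // world_size
    grads = []
    for rank in range(world_size):
        lo, hi = rank * per_rank, (rank + 1) * per_rank
        grads.append(mse_grad(w, b, xs[lo:hi], ys[lo:hi]))
    return allreduce_mean(grads)


if __name__ == "__main__":
    xs = [float(i) for i in range(8)]
    ys = [2.0 * x + 1.0 for x in xs]  # target: w = 2, b = 1
    w, b = 0.5, 0.0
    averaged = sharded_gradient(w, b, xs, ys, world_size=4)
    full_batch = mse_grad(w, b, xs, ys)
    print(averaged, full_batch)
    print(all(math.isclose(a, f) for a, f in zip(averaged, full_batch)))  # True
```

The equivalence relies on equal shard sizes, which the notebook's integer-division partitioning guarantees.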

%%px
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10),
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Scale the learning rate by hvd.size() (the number of workers) and wrap the
# optimizer with Horovod's DistributedOptimizer class.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01*hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This ensures all workers start from consistent weights, whether training
    # begins from random initialization or is restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]
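The two Horovod-specific choices in this cell can be illustrated without a cluster. BroadcastGlobalVariablesCallback(0) copies rank 0's initial variables to every other rank, so workers that were initialized from different random seeds start from identical weights. The learning-rate line follows the common linear-scaling heuristic: with 8 workers each using a batch of 128, every step processes 8 × 128 = 1,024 images in total, and the base rate of 0.01 becomes 0.08. In the plain-Python sketch below, `init_weights`, `broadcast_from_root`, and `scaled_learning_rate` are hypothetical helpers that only mimic these behaviors; they are not Horovod's implementation.

```python
import random


def init_weights(seed, size=4):
    """Random initialization; each worker uses a different seed here."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.1) for _ in range(size)]


def broadcast_from_root(per_worker_weights, root_rank=0):
    """Give every worker a copy of the root rank's weights.

    Mimics the effect of BroadcastGlobalVariablesCallback(0) at the
    start of training.
    """
    root = per_worker_weights[root_rank]
    return [list(root) for _ in per_worker_weights]


def scaled_learning_rate(base_lr, world_size):
    """Linear scaling heuristic, as in 0.01 * hvd.size() in the notebook."""
    return base_lr * world_size


if __name__ == "__main__":
    world_size, batch_size = 8, 128
    weights = [init_weights(seed=rank) for rank in range(world_size)]
    print("distinct before broadcast:", len({tuple(w) for w in weights}))
    weights = broadcast_from_root(weights)
    print("distinct after broadcast:", len({tuple(w) for w in weights}))  # 1
    print("global batch:", batch_size * world_size)                      # 1024
    print("learning rate:", scaled_learning_rate(0.01, world_size))
```

Linear scaling is a heuristic rather than a guarantee; larger worker counts may need learning-rate warmup or tuning.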

Compile and train the model

Train the model using model.fit. Each MPI engine trains on a different Gaudi device in parallel.

%%px
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size=128, callbacks=callbacks)
%px:   0%|          | 0/8 [00:00<?, ?tasks/s]
[stdout:5] 59/59 [==============================] - 3s 8ms/step - loss: 1.2938 - accuracy: 0.6984
[stdout:1] 59/59 [==============================] - 3s 8ms/step - loss: 1.2891 - accuracy: 0.6927
[stdout:7] 59/59 [==============================] - 3s 8ms/step - loss: 1.2486 - accuracy: 0.7231
[stdout:6] 59/59 [==============================] - 3s 8ms/step - loss: 1.2868 - accuracy: 0.6973
[stdout:3] 59/59 [==============================] - 3s 8ms/step - loss: 1.2819 - accuracy: 0.6929
[stdout:4] 59/59 [==============================] - 3s 8ms/step - loss: 1.3169 - accuracy: 0.6909
[stdout:0] 59/59 [==============================] - 3s 8ms/step - loss: 1.2692 - accuracy: 0.6999
[stdout:2] 59/59 [==============================] - 3s 8ms/step - loss: 1.2875 - accuracy: 0.6993
[stderr:0] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0048s vs `on_train_batch_end` time: 0.0221s). Check your callbacks.
[stderr:5] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0050s vs `on_train_batch_end` time: 0.0214s). Check your callbacks.
[stderr:4] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0049s vs `on_train_batch_end` time: 0.0204s). Check your callbacks.
[stderr:1] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0051s vs `on_train_batch_end` time: 0.0231s). Check your callbacks.
[stderr:2] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0048s vs `on_train_batch_end` time: 0.0223s). Check your callbacks.
[stderr:7] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0049s vs `on_train_batch_end` time: 0.0225s). Check your callbacks.
[stderr:6] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0050s vs `on_train_batch_end` time: 0.0211s). Check your callbacks.
[stderr:3] WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0050s vs `on_train_batch_end` time: 0.0224s). Check your callbacks.
Out[6:6]: <keras.callbacks.History at 0x7f1c0e034ac0>
Out[3:6]: <keras.callbacks.History at 0x7fd480885760>
Out[4:6]: <keras.callbacks.History at 0x7f5864a46760>
Out[1:6]: <keras.callbacks.History at 0x7f1eab38a760>
Out[2:6]: <keras.callbacks.History at 0x7f5bbb1d7ca0>
Out[0:6]: <keras.callbacks.History at 0x7efbf2daf7c0>
Out[5:6]: <keras.callbacks.History at 0x7f472e92d760>
Out[7:6]: <keras.callbacks.History at 0x7f5bd0885970>

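Evaluate the model

Each engine evaluates the full test set (x_test is not partitioned) with the same synchronized weights, so all eight engines report identical loss and accuracy.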

%%px
model.evaluate(x_test, y_test)
[stdout:2] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:1] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:5] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:7] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:6] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:3] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:0] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
[stdout:4] 313/313 [==============================] - 1s 3ms/step - loss: 0.8074 - accuracy: 0.8328
Out[2:7]: [0.807358980178833, 0.8327999711036682]
Out[5:7]: [0.807358980178833, 0.8327999711036682]
Out[1:7]: [0.807358980178833, 0.8327999711036682]
Out[3:7]: [0.807358980178833, 0.8327999711036682]
Out[7:7]: [0.807358980178833, 0.8327999711036682]
Out[6:7]: [0.807358980178833, 0.8327999711036682]
Out[0:7]: [0.807358980178833, 0.8327999711036682]
Out[4:7]: [0.807358980178833, 0.8327999711036682]

Training is complete. Remember to shut down the MPI engines to release their resources.

client.shutdown(hub=True)
Controller stopped: {'exit_code': 0, 'pid': 18274, 'identifier': 'ipcontroller-1642838956-zlnc-18257'}
engine set stopped 1642838957: {'exit_code': 0, 'pid': 18369, 'identifier': 'ipengine-1642838956-zlnc-1642838957-18257'}

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
