
Distributed Training Using TensorFlow and HPUStrategy

This tutorial demonstrates how distributed training works with HPUStrategy using Habana Gaudi AI processors.

tf.distribute.Strategy is a TensorFlow API to distribute training across multiple Gaudi devices, and multiple machines. Using this API, you can distribute your existing models and training code with minimal code changes.

To demonstrate distributed training, we will train a simple Keras model on the MNIST database.

You can find more information on distributed training with TensorFlow and HPUStrategy in the Multi-Worker Training using HPUStrategy tutorial.

Start MPI engines in a Jupyter notebook

MPI is used to coordinate work between processes. You can find a simple example of how to initialize MPI and run the model with the mpirun command here.

You can find more information on the Open MPI website.
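As a quick sanity check (a minimal sketch, assuming mpi4py is installed and the script is launched with mpirun), each process can report its MPI rank:

from mpi4py import MPI

comm = MPI.COMM_WORLD
# Launched as, e.g., "mpirun -n 8 python check_mpi.py" (hypothetical file name),
# each of the 8 processes prints a distinct rank from 0 to 7.
print(f"rank {comm.Get_rank()} of {comm.Get_size()}")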

ipyparallel and mpi4py are required to use MPI from a Jupyter notebook. If they have not been installed, install them using the following commands:

# uncomment next lines if needed
# !pip install jupyter
# !pip install ipyparallel
# !pip install mpi4py
# !pip install tensorflow-datasets

First, import the ipyparallel package, and then start the MPI engines.

In our example, we start 8 MPI engines to use all 8 Gaudi devices in the machine.

import ipyparallel as ipp
import os
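# Allow Open MPI to launch processes as the root user (the notebook often runs as root in a container).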
os.environ["OMPI_ALLOW_RUN_AS_ROOT"] = "1"
os.environ["OMPI_ALLOW_RUN_AS_ROOT_CONFIRM"] = "1"

n_hpu=8
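# Start one MPI engine per Gaudi device and block until all engines are connected.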
cluster = ipp.Cluster(engines='mpi', n=n_hpu)
client = cluster.start_and_connect_sync()
Starting 8 engines with <class 'ipyparallel.cluster.launcher.MPIEngineSetLauncher'>

Execute Python commands in parallel

The %%px cell magic executes Python commands on all the MPI engines in parallel.
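For example, a hypothetical cell like the following would run once on every engine:

%%px
# This print statement executes on each of the 8 engines in parallel.
print("Hello from an MPI engine")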

Import TensorFlow

The MPI engines have been started. The following script imports the TensorFlow library on each engine in parallel.

%%px
import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds
%px:   0%|          | 0/8 [00:00<?, ?tasks/s]

Import and enable Habana TensorFlow module

Let’s enable Gaudi devices by loading the Habana module:

%%px
from habana_frameworks.tensorflow import load_habana_module
load_habana_module()
[stderr:1] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:7] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:5] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:2] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:4] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:3] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:0] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!
[stderr:6] WARNING:/usr/local/lib/python3.8/dist-packages/habana_frameworks/tensorflow/library_loader.py:Habana-TensorFlow(1.2.0) and Habanalabs Driver(1.3.0-e793625) versions differ!

Set TF_CONFIG

TensorFlow uses the TF_CONFIG environment variable to configure distributed training. Define a helper function that sets the TF_CONFIG environment variable.

%%px

from mpi4py import MPI

BASE_TF_SERVER_PORT = 7850
SHUFFLE_BUFFER_SIZE = 10000

num_workers = MPI.COMM_WORLD.Get_size()
worker_index = MPI.COMM_WORLD.Get_rank()

def set_tf_config():
    """ Makes a TensorFlow cluster information and sets it to TF_CONFIG environment variable.
    """
    tf_config = {
        "cluster": {
            "worker": [f"localhost:{BASE_TF_SERVER_PORT + index}" for index in range(num_workers)]
        },
        "task": {"type": "worker", "index": worker_index}
    }
    tf_config_text = json.dumps(tf_config)
    os.environ["TF_CONFIG"] = tf_config_text
    print(f"TF_CONFIG = {tf_config_text}")
    return tf_config_text
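
For illustration, on worker 0 of a hypothetical two-worker cluster, this helper would print:

TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851"]}, "task": {"type": "worker", "index": 0}}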

Create a training function

To train the model on multiple Gaudi devices, import HPUStrategy from habana_frameworks.tensorflow.distribute and instantiate it as the distribution strategy. Remember to create and compile the model within the strategy.scope() context for distributed training.

%%px
def train_mnist(batch_size: int, num_epochs: int):
    """ Train the distributed model on MNIST Dataset.
    """
    # Set TF_CONFIG.
    set_tf_config()
    # Instantiate the distributed strategy class.
    # Use HPUStrategy 
    from habana_frameworks.tensorflow.distribute import HPUStrategy
    strategy = HPUStrategy()
    # Determine the total training batch size.
    batch_size_per_replica = batch_size
    total_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    print(
        f"total_batch_size = {batch_size_per_replica} * {strategy.num_replicas_in_sync} workers = {total_batch_size}")
    # Load and preprocess the MNIST Dataset.
    # As tfds.load() may download the dataset if not cached, let the first worker do it first.
    for dataload_turn in range(2):
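        # Worker 0 downloads in the first turn; the other workers then load from the cache.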
        if (dataload_turn == 0) == (worker_index == 0):
            print("Loading MNIST dataset...")
            datasets, info = tfds.load(
                name="mnist", with_info=True, as_supervised=True)
        MPI.COMM_WORLD.barrier()
    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        label = tf.cast(label, tf.int32)
        return image, label
    train_dataset = datasets["train"]
    options = tf.data.Options()
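    # DATA policy: each worker processes a disjoint subset of the training examples.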
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_dataset = train_dataset.with_options(options)
    train_dataset = train_dataset.map(
        preprocess).cache().shuffle(SHUFFLE_BUFFER_SIZE).batch(total_batch_size)
    test_dataset = datasets["test"]
    options = tf.data.Options()
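    # OFF: every worker evaluates on the full, identical test set.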
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    test_dataset = test_dataset.with_options(options)
    test_dataset = test_dataset.map(
        preprocess).batch(total_batch_size)
    # Create and compile the distributed CNN model.
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                32, 3, activation="relu", input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(10)
        ])
        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=["accuracy"])
        
    # Train the model.
    print("Calling model.fit()...")
    model.fit(train_dataset, epochs=num_epochs, verbose=2)
    print("Calling model.evaluate()...")
    eval_results = model.evaluate(test_dataset, verbose=2)
    print(f"Evaluation results: {eval_results}")

Train the model

Now call the training function defined above to run model training.

%%px
if __name__ == "__main__":
    train_mnist(batch_size=64, num_epochs=2)
[stdout:0] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 0}}
   local_devices=('/job:worker/task:0/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 84ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stdout:6] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 6}}
   local_devices=('/job:worker/task:6/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 85ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stdout:7] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 7}}
   local_devices=('/job:worker/task:7/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 85ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stderr:0] 2022-01-22 08:08:42.777161: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.329198: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.334855: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.339890: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.342238: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7850
2022-01-22 08:08:44.740452: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200230: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stderr:6] 2022-01-22 08:08:42.777690: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.360132: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.362736: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.367844: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.369731: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7856
2022-01-22 08:08:44.738346: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200009: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stdout:2] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 2}}
   local_devices=('/job:worker/task:2/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 84ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stderr:7] 2022-01-22 08:08:42.778841: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.363850: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.368635: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.373796: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.375903: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7857
2022-01-22 08:08:44.740024: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.199857: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stdout:4] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 4}}
   local_devices=('/job:worker/task:4/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 85ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stdout:5] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 5}}
   local_devices=('/job:worker/task:5/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 921ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 84ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stdout:1] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 1}}
   local_devices=('/job:worker/task:1/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 922ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 85ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stdout:3] TF_CONFIG = {"cluster": {"worker": ["localhost:7850", "localhost:7851", "localhost:7852", "localhost:7853", "localhost:7854", "localhost:7855", "localhost:7856", "localhost:7857"]}, "task": {"type": "worker", "index": 3}}
   local_devices=('/job:worker/task:3/device:HPU:0',)
total_batch_size = 64 * 8 workers = 512
Loading MNIST dataset...
Calling model.fit()...
Epoch 1/2
118/118 - 7s - loss: 0.4593 - accuracy: 0.8770 - 7s/epoch - 62ms/step
Epoch 2/2
118/118 - 1s - loss: 0.1515 - accuracy: 0.9567 - 921ms/epoch - 8ms/step
Calling model.evaluate()...
20/20 - 2s - loss: 0.1042 - accuracy: 0.9664 - 2s/epoch - 84ms/step
Evaluation results: [0.10418514907360077, 0.9664062261581421]
[stderr:2] 2022-01-22 08:08:42.778554: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.557263: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.561464: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.571647: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.575023: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7852
2022-01-22 08:08:44.741001: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200107: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stderr:4] 2022-01-22 08:08:42.777822: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.581317: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.587683: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.597617: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.601518: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7854
2022-01-22 08:08:44.739501: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200326: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stderr:1] 2022-01-22 08:08:42.776619: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.605777: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.608621: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.618580: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.620449: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7851
2022-01-22 08:08:44.740687: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200290: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stderr:5] 2022-01-22 08:08:42.777927: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.601381: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.606835: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.611995: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.614041: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7855
2022-01-22 08:08:44.740553: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.199882: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
[stderr:3] 2022-01-22 08:08:42.777006: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-22 08:08:44.615701: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.620038: W /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/habana_device/habana_device.cpp:182] HPU initialization done for library version 1.2.0_c6aea18b_tf2.7.0
2022-01-22 08:08:44.627451: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:7850, 1 -> localhost:7851, 2 -> localhost:7852, 3 -> localhost:7853, 4 -> localhost:7854, 5 -> localhost:7855, 6 -> localhost:7856, 7 -> localhost:7857}
2022-01-22 08:08:44.630222: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:7853
2022-01-22 08:08:44.771511: I /home/jenkins/workspace/cdsoftwarebuilder/create-tensorflow-module---bpt-d/tensorflow-training/synapse_helpers/hccl_communicator.cpp:55] Opening communication. Device id:0.
2022-01-22 08:08:57.200331: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Training is done! Each MPI engine ran model.fit() on a different Gaudi device in parallel. Remember to shut down the MPI engines to release resources.


client.shutdown(hub=True)
Controller stopped: {'exit_code': 0, 'pid': 9307, 'identifier': 'ipcontroller-1642838911-z6z6-9290'}
engine set stopped 1642838912: {'exit_code': 0, 'pid': 9402, 'identifier': 'ipengine-1642838911-z6z6-1642838912-9290'}

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
