Habana Developer Blog

Leveraging Intel Gaudi for Distributed Training with FSDP

Learn how to scale model training with Fully Sharded Data Parallel (FSDP) using PyTorch and Intel Gaudi AI accelerators

The pursuit of ubiquitous and accessible AI compute is a persistent struggle for developers across the spectrum. Intel® Gaudi® AI accelerators are helping change this narrative by offering a price-performant alternative to high-performance GPUs such as NVIDIA’s A100 and H100. Not only is the hardware performance there, but so is the software support.

More broadly, Gaudi supports, and includes optimizations for, major frameworks and libraries such as PyTorch, Hugging Face, and DeepSpeed. For model training specifically, Gaudi supports techniques such as parameter-efficient fine-tuning (PEFT) with Low-Rank Adaptation (LoRA), DeepSpeed ZeRO and 3D parallelism, and now PyTorch's Fully Sharded Data Parallel (FSDP). These integrations and optimizations reflect Gaudi’s commitment to the open-source developer ecosystem.

The more recent enablement of FSDP gives developers yet another powerful tool for developing GenAI models at scale.

What is FSDP?

Fully Sharded Data Parallel (FSDP) is a PyTorch data-parallel training technique, now supported on Intel Gaudi AI accelerators, that enables distributed training of large-scale models.

In distributed data parallel (DDP), PyTorch's conventional data-parallel technique, every worker holds a full replica of the model weights, gradients, and optimizer states, which makes the per-worker memory footprint large. FSDP instead shards model parameters, gradients, and optimizer states across all ranks, so no worker needs a full copy of the model and the memory footprint on each Gaudi card shrinks accordingly.

Figure 1: Diagram demonstrating the difference between FSDP and DDP. Refer to the FSDP Theory of Operations for more information. (Image by author)
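The sharding described above can be illustrated with a single-process, pure-Python simulation of one training step under FSDP's FULL_SHARD strategy, compared against a DDP step. This is a conceptual sketch only, not how PyTorch implements FSDP: the helpers `shard`, `all_gather`, `reduce_scatter`, `fsdp_sgd_step` and `ddp_sgd_step` are hypothetical names, the collectives that FSDP runs over HCCL on Gaudi are imitated with Python lists, and the toy loss 0.5·Σ(w − x)² is an assumption chosen so the gradient is easy to write by hand.

```python
from typing import List

Vector = List[float]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Pad a flattened tensor to a multiple of world_size and split it into
    equal contiguous shards, one per rank."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[Vector], numel: int) -> Vector:
    """Concatenate every rank's shard and drop the padding."""
    return [x for s in shards for x in s][:numel]


def reduce_scatter(grads_per_rank: List[Vector]) -> List[Vector]:
    """Average the per-rank gradients element-wise, then give each rank only its slice."""
    world_size = len(grads_per_rank)
    numel = len(grads_per_rank[0])
    avg = [sum(g[i] for g in grads_per_rank) / world_size for i in range(numel)]
    return shard(avg, world_size)


def local_grad(params: Vector, x: Vector) -> Vector:
    """Gradient of the toy loss 0.5 * sum((w - x)**2) with respect to w."""
    return [w - xi for w, xi in zip(params, x)]


def fsdp_sgd_step(param_shards: List[Vector], inputs: List[Vector], lr: float) -> List[Vector]:
    """One simulated FULL_SHARD SGD step; inputs[r] is rank r's local data."""
    numel = len(inputs[0])
    # Forward/backward: each rank all-gathers the full parameters for the
    # computation and frees them afterwards; only its shard is kept.
    full = all_gather(param_shards, numel)
    grads = [local_grad(full, x) for x in inputs]
    # Reduce-scatter leaves each rank with the averaged gradient for its shard only.
    grad_shards = reduce_scatter(grads)
    # Each rank's optimizer updates only the shard it owns.
    return [[p - lr * g for p, g in zip(ps, gs)]
            for ps, gs in zip(param_shards, grad_shards)]


def ddp_sgd_step(params: Vector, inputs: List[Vector], lr: float) -> Vector:
    """Reference DDP step: every rank holds all parameters and averages all gradients."""
    grads = [local_grad(params, x) for x in inputs]
    avg = [sum(g[i] for g in grads) / len(grads) for i in range(len(params))]
    return [w - lr * g for w, g in zip(params, avg)]


params = [0.5, -1.0, 2.0]
inputs = [[1.0, 2.0, 3.0], [3.0, 0.0, -1.0]]  # local data for two ranks
shards = shard(params, world_size=2)           # each rank stores 2 of 4 padded values
new_shards = fsdp_sgd_step(shards, inputs, lr=0.1)
print(all_gather(new_shards, len(params)))
print(ddp_sgd_step(params, inputs, lr=0.1))
```

Both lines print the same updated parameters. The difference is what each rank holds between steps: under DDP every rank keeps all three values (plus their gradients and optimizer state), while in the FSDP simulation each rank keeps only its two-element shard.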

FSDP runs on Intel Gaudi in both eager mode and with torch.compile, offering an efficient approach to large-scale model training and an alternative to DeepSpeed's popular distributed training functionality.

Multi-card fine-tuning with FSDP

As with other Intel Gaudi workflows, we use the popular Optimum Habana library to enable FSDP, here to fine-tune Llama2-70B with FSDP and LoRA. Optimum Habana is the interface between 🤗 Transformers, Diffusers, and the Intel Gaudi AI accelerator. It simplifies model loading, training, and inference on single- and multi-card configurations for a variety of downstream tasks.

Environment Setup

Before getting started, you will need to set up your environment. 

  1. Start by provisioning an Intel Gaudi instance from the Intel® Tiber® Developer Cloud. See Getting Access on the Developer webpage for more information.
  2. Next, follow these instructions to use the Intel Gaudi PyTorch Docker image. Although Gaudi can be used without containers, we find the Docker image to be the most straightforward way to get started.
  3. From inside the container, install the Optimum Habana library and clone the Optimum Habana GitHub repository:
    • pip install optimum-habana==1.11.0
    • cd ~
    • git clone -b v1.11.0 https://github.com/huggingface/optimum-habana.git
  4. Navigate to the optimum-habana/examples/language-modeling folder and run pip install -r requirements.txt to install the relevant dependencies.
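Before moving on, a quick check that the setup steps took effect can save debugging time later. The script below is a sketch of what such a check might look like, not an official tool: `installed_version` and `has_module` are hypothetical helper names built on the standard library's `importlib.metadata` and `importlib.util`, and the pinned version 1.11.0 comes from step 3.

```python
"""Sanity-check the Optimum Habana environment (run inside the Gaudi container)."""
import importlib.util
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_module(module_name: str) -> bool:
    """True if the module can be located (parent packages may get imported)."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name does not exist.
        return False


if __name__ == "__main__":
    version = installed_version("optimum-habana")
    print("optimum-habana:", version or "not installed")
    if version not in (None, "1.11.0"):
        print("warning: this walkthrough pins optimum-habana==1.11.0")
    print("Gaudi PyTorch bridge found:", has_module("habana_frameworks.torch"))
```

Outside the Gaudi container both checks simply report that the packages are missing, so the script is safe to run anywhere.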

An Example of FSDP on Intel Gaudi

Now that the environment is set up, let’s take a deep dive into the relevant Gaudi code using the simplified Gaudi FSDP implementation below:

import os
import argparse

# Enable eager mode. Set this before the Gaudi PyTorch bridge is imported.
os.environ["PT_HPU_LAZY_MODE"] = "0"

import torch
import torch.nn as nn
from torch.nn import Linear
from torch.optim import SGD
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Registers the HCCL collective backend with torch.distributed.
import habana_frameworks.torch.distributed.hccl

# Target the Gaudi device.
device_hpu = torch.device('hpu')


def setup(rank, world_size):
    # Single-node rendezvous; every spawned process joins the same group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = Linear(3, 3, bias=False)

    def forward(self, x):
        return self.lin1(x)


def simple_demo(rank, world_size, args):
    setup(rank, world_size)

    # Place the model on the Gaudi device, then wrap it with FSDP so its
    # parameters, gradients and optimizer state are sharded across ranks.
    model = ToyModel().to(device_hpu)
    model = FSDP(model, device_id=device_hpu)

    # Create the optimizer after wrapping so it references the sharded parameters.
    optim = SGD(model.parameters(), lr=0.1)

    # Each rank trains on its own row of a random (8, 3) batch.
    inputs = torch.rand(8, 3)
    in_data = inputs[rank].to(device_hpu)

    for _ in range(args.iterations):
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

    cleanup()
    print(f"Rank {rank}: all done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Simple_demo Example')
    parser.add_argument('--iterations', type=int, default=5, metavar='I', help='iterations (default: 5)')
    parser.add_argument('--ws', type=int, default=2, help='world size (default: 2)')
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbosity")

    args = parser.parse_args()
    # The toy batch has 8 rows (one per rank), matching the 8 cards in a Gaudi node.
    if not 1 <= args.ws <= 8:
        parser.error("--ws must be between 1 and 8")
    if args.verbose:
        os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
        torch._dynamo.config.verbose = True
    WORLD_SIZE = args.ws
    mp.spawn(simple_demo,
             args=(WORLD_SIZE, args),
             nprocs=WORLD_SIZE,
             join=True)

The Gaudi-specific changes to the code are:

  • Import the HCCL backend: import habana_frameworks.torch.distributed.hccl, which makes dist.init_process_group(backend='hccl', ...) available
  • Enable eager mode: os.environ["PT_HPU_LAZY_MODE"] = "0"
  • Target the Gaudi device: device_hpu = torch.device('hpu')
  • Target the Gaudi device for model execution: model = ToyModel().to(device_hpu)
  • Run FSDP on Gaudi: model = FSDP(model, device_id=device_hpu)

As you can see, only minimal code changes are required. Wrapping the model with FSDP(model, device_id=device_hpu) is where FSDP is enabled: it partitions the model’s parameters across the Gaudi cards, so each card manages only a fraction of the total parameters. As mentioned earlier, this significantly reduces the memory required per card and makes it possible to scale to larger models and more Gaudi cards. The script launches two processes by default; to use all 8 Gaudi cards, pass the --ws 8 (world size) flag.

To run this example, save the code above as fsdp_example.py and run `python3 fsdp_example.py`.
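The per-card memory savings described above can be made concrete with rough arithmetic. The sketch below is illustrative only: `per_card_gb` is a hypothetical helper, and its default of 16 bytes per parameter is an assumption for full fine-tuning with fp32 weights, fp32 gradients, and Adam's two fp32 moment buffers. Activations, the temporary buffers FSDP uses to all-gather each unit, and framework overhead are ignored. The Llama2-70B example later in this article uses LoRA, where only small adapter matrices carry gradients and optimizer state, so its real numbers differ.

```python
def per_card_gb(num_params: int, world_size: int, sharded: bool,
                bytes_per_param: int = 16) -> float:
    """Approximate per-card memory (decimal GB) for weights, gradients and
    optimizer state only.

    bytes_per_param=16 assumes fp32 weights (4 bytes), fp32 gradients
    (4 bytes) and Adam's two fp32 moment buffers (8 bytes).
    """
    total_bytes = num_params * bytes_per_param
    # DDP replicates everything on every card; FULL_SHARD splits it evenly
    # (rounded up, as FSDP pads to a multiple of the world size).
    per_card = -(-total_bytes // world_size) if sharded else total_bytes
    return per_card / 1e9


if __name__ == "__main__":
    n = 70_000_000_000  # a 70B-parameter model, fully fine-tuned
    for ws in (1, 2, 8):
        print(f"world size {ws}: DDP {per_card_gb(n, ws, False):.0f} GB, "
              f"FSDP {per_card_gb(n, ws, True):.0f} GB")
```

Under these assumptions DDP needs 1120 GB on every card regardless of world size, while FSDP drops to 560 GB on 2 cards and 140 GB on 8, which is why sharding (often combined with LoRA) is needed to fit models of this size.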

Multi-card fine-tuning of Llama2-70B with FSDP and LoRA

Once you’ve mastered this simple implementation, you can move on to a more comprehensive example from the Optimum Habana repository. Many examples in the repository come with prepackaged convenience scripts, utilities, and configuration files; typically, you only need to adjust a few parameters and run a single command to get started.

Go to the language-modeling folder in the Optimum Habana examples, install the requirements, and run the command below. Note that it launches the workers with MPI (the --use_mpi option of gaudi_spawn.py) rather than DeepSpeed to run multi-card fine-tuning of Llama2-70B with FSDP and LoRA:

  • cd ~/optimum-habana/examples/language-modeling
  • pip install -r requirements.txt
LOWER_LIST=ops_bf16.txt PT_HPU_LAZY_MODE=0 \
python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \
  --model_name_or_path meta-llama/Llama-2-70b-hf \
  --dataset_name tatsu-lab/alpaca \
  --bf16 True \
  --output_dir ./lora_out \
  --max_seq_len 2048 \
  --gradient_checkpointing \
  --per_device_train_batch_size 5 \
  --save_strategy no \
  --learning_rate 0.0004 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "constant" \
  --logging_steps 1 \
  --dataset_concatenation \
  --do_train \
  --use_habana \
  --throughput_warmup_steps 3 \
  --lora_rank 4 \
  --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
  --attn_softmax_bf16 True \
  --validation_split_percentage 4 \
  --use_lazy_mode False \
  --fsdp_config fsdp_config.json \
  --fsdp auto_wrap \
  --num_train_epochs 2 \
  --evaluation_strategy epoch \
  --per_device_eval_batch_size 1 \
  --eval_delay 2 \
  --do_eval \
  --pipelining_fwd_bwd False \
  --use_fused_rope False \
  --torch_compile_backend hpu_backend \
  --torch_compile \
  --gradient_accumulation_steps 2
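Because FSDP is still data parallelism, each of the 8 workers processes its own micro-batch, and gradients are accumulated over several micro-batches before each optimizer step. The sketch below works out the resulting global batch from the flags above. The helper names are hypothetical, and the token count is an upper bound that assumes --dataset_concatenation packs every training sequence to the full --max_seq_len.

```python
def global_batch_size(per_device_batch: int, world_size: int,
                      grad_accum_steps: int) -> int:
    """Sequences that contribute to one optimizer step in data-parallel training."""
    return per_device_batch * world_size * grad_accum_steps


def max_tokens_per_step(per_device_batch: int, world_size: int,
                        grad_accum_steps: int, seq_len: int) -> int:
    """Upper bound on tokens per optimizer step, assuming every sequence is
    packed to the full seq_len."""
    return global_batch_size(per_device_batch, world_size, grad_accum_steps) * seq_len


# Values taken from the command above:
# --per_device_train_batch_size 5, --world_size 8,
# --gradient_accumulation_steps 2, --max_seq_len 2048
print(global_batch_size(5, 8, 2))          # 80
print(max_tokens_per_step(5, 8, 2, 2048))  # 163840
```

Keeping this product in mind helps when changing --world_size: halving the number of cards while keeping the same global batch means doubling either the per-device batch size or the gradient accumulation steps.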

The FSDP-relevant arguments are:

  • --fsdp auto_wrap: the auto_wrap policy groups layers into multiple collections called “FSDP units”, each of which gathers its parameters independently of the others. As a result, the full set of model parameters never has to be loaded at once during FSDP fine-tuning, which saves memory.
  • --fsdp_config fsdp_config.json: provides configuration parameters for FSDP. The full list of supported options is in the FSDP documentation. A few of the most important settings in the file below are:
    – "fsdp_backward_prefetch": "BACKWARD_PRE" prefetches the parameters of the next FSDP unit during the backward pass, before the current unit’s gradient computation begins, overlapping communication with computation.
    – "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP" wraps every instance of a user-specified transformer layer class in its own FSDP unit.
    – "transformer_layer_cls_to_wrap": "GaudiLlamaDecoderLayer" names that layer class, here the Gaudi-optimized Llama decoder layer.
    – "fsdp_sharding_strategy": 1 selects FULL_SHARD, the default FSDP strategy, which shards parameters, gradients, and optimizer states.
{
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_backward_prefetch": "BACKWARD_PRE",
    "fsdp_forward_prefetch": false,
    "fsdp_offload_params": false,
    "fsdp_sharding_strategy": 1,
    "fsdp_state_dict_type": "FULL_STATE_DICT",
    "fsdp_sync_module_states": true,
    "fsdp_use_orig_params": true,
    "transformer_layer_cls_to_wrap": "GaudiLlamaDecoderLayer",
    "fsdp_activation_checkpointing": false
}
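Because a typo in fsdp_config.json only surfaces once the eight workers have started, it can be worth sanity-checking the file first. The sketch below is a hypothetical helper, not part of Optimum Habana: it assumes "fsdp_sharding_strategy" is an integer code as in the file above, and that the codes follow the order of PyTorch’s `torch.distributed.fsdp.ShardingStrategy` enum, where FULL_SHARD is 1.

```python
import json

# Integer codes for "fsdp_sharding_strategy", following the order of
# torch.distributed.fsdp.ShardingStrategy (FULL_SHARD is 1).
SHARDING_STRATEGIES = {1: "FULL_SHARD", 2: "SHARD_GRAD_OP", 3: "NO_SHARD", 4: "HYBRID_SHARD"}


def summarize_fsdp_config(text: str) -> dict:
    """Parse an fsdp_config.json string, check a few fields, and return a summary.

    Hypothetical helper for illustration only.
    """
    cfg = json.loads(text)
    code = cfg.get("fsdp_sharding_strategy")
    if code not in SHARDING_STRATEGIES:
        raise ValueError(f"unknown fsdp_sharding_strategy: {code!r}")
    policy = cfg.get("fsdp_auto_wrap_policy")
    # A transformer-based wrap policy is meaningless without a layer class to wrap.
    if policy == "TRANSFORMER_BASED_WRAP" and not cfg.get("transformer_layer_cls_to_wrap"):
        raise ValueError("TRANSFORMER_BASED_WRAP requires transformer_layer_cls_to_wrap")
    return {
        "sharding_strategy": SHARDING_STRATEGIES[code],
        "wrap_policy": policy,
        "wrapped_layer": cfg.get("transformer_layer_cls_to_wrap"),
        "backward_prefetch": cfg.get("fsdp_backward_prefetch"),
        "offload_params": cfg.get("fsdp_offload_params", False),
    }
```

Running `with open("fsdp_config.json") as f: print(summarize_fsdp_config(f.read()))` on the configuration above should report FULL_SHARD sharding, TRANSFORMER_BASED_WRAP around GaudiLlamaDecoderLayer, BACKWARD_PRE prefetching, and no parameter offloading.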

Conclusion and Discussion

In this article, we’ve explored the integration of Intel Gaudi accelerators with Fully Sharded Data Parallel (FSDP) to enhance the efficiency and scalability of large model training. The practical steps outlined, from environment setup to detailed code examples, showcase how to leverage Intel’s hardware alongside cutting-edge model training techniques. By reducing the memory footprint through techniques such as parameter sharding and enabling efficient distributed training, developers can now more effectively tackle large-scale GenAI model development. 

Next Steps

  • Expanding FSDP Usage with Gaudi Accelerators: Integrate FSDP into more complex AI models and training scenarios. 
  • Deep Dive into Fine-Tuning Techniques: This could include varying the hyperparameters, trying fine-tuning strategies beyond the LoRA setup shown here, and comparing the performance outcomes.