
Fine-Tuning Llama2-70B with DeepSpeed ZeRO-3 and Low-Rank Adaptation (LoRA) on Intel® Gaudi®2 AI Accelerator

With the Intel Gaudi SynapseAI 1.13.0 release, users can fine-tune the Llama2-70B model using only 8 Gaudi2 accelerators.

Fine-tuning large language models (LLMs) with billions of parameters such as Llama2-70B is a challenging task that demands huge memory and high computational resources. At bfloat16 precision, a single model parameter requires 2 bytes of memory. Thus, simply loading 70-billion parameters of Llama2-70B will require 140GB of device memory. Additionally, more memory is required to accommodate optimizer states and gradients of the model during the training process.

In this article, we will explore how to fine-tune Llama2-70B with DeepSpeed ZeRO-3 and Low-Rank Adaptation (LoRA) techniques on 8x Intel Gaudi2 AI accelerators.

DeepSpeed ZeRO-3 Optimization

DeepSpeed is a deep learning optimization library that enables the scaling of models for training and inference. The Zero Redundancy Optimizer (ZeRO) is a memory optimization technique within DeepSpeed, comprising three optimization stages. Stage 3 of ZeRO (ZeRO-3) reduces memory consumption in distributed training by partitioning the optimizer states, gradients, and model parameters across the worker processes. The following diagram shows that each worker holds only a subset of the parameters. Just before a forward or backward pass is executed, the necessary parameters are gathered with collective communication operations. After execution, parameters that are not needed again until the next forward or backward pass are released. Moreover, in the parameter update phase, each worker updates only the optimizer states corresponding to the parameters assigned to it.

Figure: DeepSpeed ZeRO-3 optimization partitions optimizer states, gradients, and model parameters across workers.

Although DeepSpeed ZeRO-3 optimization can significantly reduce memory usage, full-parameter fine-tuning of Llama2-70B is still not feasible, even on 8x Gaudi2 cards. For a model with 70-billion parameters, the total memory requirement is approximately 1.1TB, or roughly 140GB per Gaudi2 card on an HLS-2 server: loading the model parameters in BF16 precision consumes 140GB (2 Bytes * 70B), the gradients in BF16 precision require another 140GB (2 Bytes * 70B), and the optimizer states of the Adam optimizer in FP32 (a master copy of the parameters, the momentum of the gradients, and the variance of the gradients) occupy 840GB (3 * 4 Bytes * 70B). Thus, we also introduced a Parameter-Efficient Fine-Tuning (PEFT) method to fine-tune only a subset of parameters and further reduce resource utilization.
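The back-of-the-envelope numbers above can be reproduced with a few lines of Python; the per-card figure simply assumes the total is spread evenly across the 8 Gaudi2 cards of an HLS-2 server.

# Rough memory estimate for full-parameter fine-tuning of a 70B-parameter model with Adam
params = 70e9

weights_bf16 = 2 * params           # BF16 model weights: 2 bytes per parameter
grads_bf16 = 2 * params             # BF16 gradients: 2 bytes per parameter
adam_states_fp32 = 3 * 4 * params   # FP32 master weights, momentum, and variance: 12 bytes per parameter

total_bytes = weights_bf16 + grads_bf16 + adam_states_fp32
print(f"Total: {total_bytes / 1e12:.2f} TB")                    # ~1.12 TB
print(f"Per card (8 cards): {total_bytes / 8 / 1e9:.0f} GB")    # ~140 GB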

Parameter-Efficient Fine-Tuning with Low-Rank Adaptation (LoRA)

Parameter-Efficient Fine-Tuning (PEFT) is a cost-effective alternative to the resource-intensive full fine-tuning of large language models. It fine-tunes only a small number of model parameters to adapt the pre-trained model to a specific downstream task, instead of fine-tuning the entire model. Low-Rank Adaptation (LoRA) is one of the most widely used PEFT methods. LoRA dramatically reduces the number of trainable parameters by freezing the pre-trained model weights and performing the weight update through low-rank matrices. This works because the fine-tuned weight can be expressed as the sum of the pre-trained weight (W0) and the accumulated gradient update (ΔW), and ΔW can be decomposed into the product of two low-rank matrices, B and A.

W′ = W0 + ΔW = W0 + BA

W′: weight matrix after fine-tuning, W′ ∈ R^(d×k)
W0: pre-trained weight matrix, W0 ∈ R^(d×k)
ΔW: accumulated gradient update of W0 during fine-tuning, ΔW ∈ R^(d×k)
A, B: trainable low-rank matrices, B ∈ R^(d×r), A ∈ R^(r×k), where r ≪ min(d, k)

In the forward pass, the input features are multiplied by both the pre-trained weight W0 and the accumulated gradient update ΔW = BA, and the two outputs are summed to produce the result. During the backward pass, A and B receive gradient updates while the pre-trained full-rank weights remain frozen.

Figure: LoRA forward pass with the frozen pre-trained weight W0 and the trainable low-rank matrices A and B.
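To make the mechanism concrete, here is a minimal PyTorch sketch of a LoRA-wrapped linear layer. This is a hypothetical illustration rather than the code used by the example script below; the α/r scaling follows the convention of the LoRA paper.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper: a frozen pre-trained linear layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 16, dropout: float = 0.05):
        super().__init__()
        self.base = base
        for p in self.base.parameters():   # freeze the pre-trained weight W0 (and bias)
            p.requires_grad_(False)
        # A ∈ R^(r×k), B ∈ R^(d×r); B starts at zero so that ΔW = BA = 0 before training
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output of the frozen weight plus the low-rank update: W0·x + (BA)·x, scaled by α/r
        delta_w = self.lora_B @ self.lora_A            # shape (d, k)
        return self.base(x) + (self.dropout(x) @ delta_w.T) * self.scaling

Replacing a pre-trained nn.Linear with LoRALinear(linear, r=4, alpha=16, dropout=0.05) leaves W0 untouched and trains only A and B, which is exactly the behavior described above.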

Fine-Tuning Llama2-70B on 8x Gaudi2 Cards with ZeRO-3 and LoRA

In the Gaudi SynapseAI 1.13.0 release, we enabled Llama2-70B fine-tuning on 8x Gaudi2 cards with DeepSpeed ZeRO-3 optimization and LoRA. To improve the model’s training performance, we added support for running the softmax in the attention layer in bfloat16 precision without compromising the accuracy of the outputs. Furthermore, memory consumption with DeepSpeed ZeRO-3 has been optimized by constraining the internal graph size and adding synchronization points. These optimizations are enabled through the PT_HPU_MAX_COMPOUND_OP_SIZE and DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED environment variables, which are set alongside the training command shown below.

To apply DeepSpeed ZeRO-3 optimization to the fine-tuning on Intel Gaudi 2, “stage” is set to 3, and “overlap_comm” and “contiguous_gradients” are set to “false” within the dictionary under the “zero_optimization” entry. These DeepSpeed settings are provided in a .json file; for this example, the file is already available in the Optimum-Habana GitHub repository (llama2_ds_zero3_config.json) and is referenced in the runtime command below.
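Based on these settings, the zero_optimization section of the configuration file looks roughly like the sketch below; the actual llama2_ds_zero3_config.json in the Optimum-Habana repository may contain additional fields (for example, batch-size and precision settings), so refer to that file for the exact contents.

{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": false,
    "contiguous_gradients": false
  }
}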

For LoRA, we injected the trainable low-rank matrices into the “q_proj,” “k_proj,” “v_proj,” and “o_proj” modules, using a LoRA rank of 4, a LoRA α of 16, and a dropout probability of 0.05 in the LoRA configuration.
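The run_lora_clm.py example script builds its LoRA configuration from command-line arguments such as --lora_rank and --lora_target_modules shown below; expressed directly with the Hugging Face PEFT library, an equivalent configuration would look roughly like this sketch (the bias and task_type values are typical choices for causal-language-model fine-tuning, not settings stated in this article).

from peft import LoraConfig

# LoRA settings used in this example: rank 4, alpha 16, dropout 0.05,
# injected into the attention projection modules of Llama2-70B
lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)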

In this example, we fine-tuned Llama2-70B on the Alpaca dataset for 2 epochs to converge, using a local batch size of 10 and a maximum sequence length of 2048. Please note that the training batch size of 10 was selected for improved accuracy, not to maximize memory usage. A larger batch size also fits in device memory, but because the Alpaca dataset then yields fewer weight updates per epoch, convergence becomes more challenging. We added a Llama2-70B fine-tuning example to the Optimum-Habana GitHub repository. Optimum-Habana is an interface between the Hugging Face Transformers library and the Intel Gaudi AI accelerator. To run the example, pull the Docker image from the Habana Vault, then clone the Optimum-Habana repository and install Optimum-Habana inside the Docker container. Also install Habana DeepSpeed and the Python packages required for Llama2-70B fine-tuning.

You will need to authenticate your Hugging Face account to be able to download the Llama 2 Model. See the Addendum below for more details.

# Install Habana DeepSpeed and Optimum-Habana
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.14.0
pip install optimum-habana==1.10.0

# Clone the Optimum-Habana repository to get the language-modeling example
git clone https://github.com/huggingface/optimum-habana.git
cd optimum-habana/
git checkout v1.10.0

# Install the example's Python dependencies
cd examples/language-modeling
pip install -r requirements.txt

# Authenticate with Hugging Face to download the gated Llama 2 model
huggingface-cli login --token <your_huggingface_read_token>

To execute the fine-tuning example on 8x Gaudi2 accelerators, go to the optimum-habana/examples/language-modeling directory and run the following command:

PT_HPU_MAX_COMPOUND_OP_SIZE=10 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 \
python3 ../gaudi_spawn.py --use_deepspeed  --world_size 8  run_lora_clm.py \
  --model_name_or_path meta-llama/Llama-2-70b-hf \
  --deepspeed llama2_ds_zero3_config.json \
  --dataset_name tatsu-lab/alpaca \
  --bf16 True \
  --output_dir ./lora_out \
  --num_train_epochs 2 \
  --max_seq_len 2048 \
  --per_device_train_batch_size 10 \
  --per_device_eval_batch_size 10 \
  --gradient_checkpointing \
  --evaluation_strategy epoch \
  --eval_delay 2 \
  --save_strategy no \
  --learning_rate 0.0018 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --dataset_concatenation \
  --attn_softmax_bf16 True \
  --do_train \
  --do_eval \
  --use_habana \
  --use_lazy_mode \
  --pipelining_fwd_bwd \
  --throughput_warmup_steps 3 \
  --lora_rank 4 \
  --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
  --validation_split_percentage 4

It takes approximately 44 minutes to fine-tune Llama2-70B on 8x Gaudi2 cards for 2 epochs to converge.

Conclusion

In this blog, we showed how we enabled Llama2-70B fine-tuning on 8x Intel Gaudi2 AI accelerators by applying DeepSpeed ZeRO-3 optimization and the LoRA technique. While the example in this blog focuses on Llama2-70B, these methodologies are broadly applicable to other LLMs.

We are continuously working to improve the performance of Llama2-70B as well as other popular LLMs in upcoming releases. Stay tuned for more release updates, which generally occur on a 6-to-8-week cadence.

References

  1. Rajbhandari et al., “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models”, arXiv:1910.02054
  2. Hu et al., “LoRA: Low-Rank Adaptation of Large Language Models”, arXiv:2106.09685
  3. https://developer.habana.ai/blog/memory-efficient-training-on-habana-gaudi-with-deepspeed/

Addendum: How to Access and Use the Llama 2 model

Use of the pretrained model is subject to compliance with third-party licenses, including the “Llama 2 Community License Agreement” (LLAMAV2). For guidance on the intended use of the Llama 2 model, what is considered misuse and out-of-scope use, who the intended users are, and additional terms, please review the instructions at https://ai.meta.com/llama/license/. Users bear sole liability and responsibility for following and complying with any third-party licenses, and Habana Labs disclaims and will bear no liability with respect to users’ use of or compliance with third-party licenses.

To run a gated model such as Llama-2-70b-hf, you need the following:

  • Have a Hugging Face account
  • Agree to the model’s terms of use in its model card on the Hugging Face Hub
  • Create a read access token in your Hugging Face account settings
  • Log in to your account using the Hugging Face CLI: run huggingface-cli login before launching your script