Fine-tuning Llama 2 70B using PyTorch FSDP



In this blog post, we will look at how to fine-tune Llama 2 70B using PyTorch FSDP and related best practices. We will be leveraging Hugging Face Transformers, Accelerate and TRL. We will also learn how to use Accelerate with SLURM.

Fully Sharded Data Parallelism (FSDP) is a paradigm in which the optimizer states, gradients and parameters are sharded across devices. During the forward pass, each FSDP unit performs an all-gather operation to get the complete weights, performs its computation, and then discards the shards gathered from other devices. After the forward pass, the loss is computed, followed by the backward pass. In the backward pass, each FSDP unit again performs an all-gather operation to get the complete weights, and computation is performed to get the local gradients. These local gradients are averaged and sharded across the devices via a reduce-scatter operation, so that each device can update the parameters of its own shard. For more information on what PyTorch FSDP is, please refer to this blog post: Accelerate Large Model Training using PyTorch Fully Sharded Data Parallel.
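To make this data flow concrete, here is a toy, single-process simulation of one FSDP training step in pure Python. It is only a sketch under simplifying assumptions: the "ranks" run one after another in a single process, the model is a linear regressor with loss 0.5 * (w·x - y)^2, and the helpers shard, all_gather, reduce_scatter and fsdp_step are hypothetical stand-ins for the collectives, not PyTorch APIs.

```python
from typing import List, Tuple

Vector = List[float]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Split a flat parameter vector into equal contiguous shards, one per rank.

    Assumes len(flat) is divisible by world_size (real FSDP pads instead).
    """
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards: List[Vector]) -> Vector:
    """Concatenate every rank's shard into the full vector."""
    return [x for s in shards for x in s]


def reduce_scatter(local_grads: List[Vector]) -> List[Vector]:
    """Average gradients across ranks, then hand each rank only its own slice."""
    world_size = len(local_grads)
    averaged = [sum(vals) / world_size for vals in zip(*local_grads)]
    return shard(averaged, world_size)


def fsdp_step(param_shards: List[Vector],
              batches: List[Tuple[Vector, float]],
              lr: float) -> List[Vector]:
    """One toy step: each rank sees its own (x, y) micro-batch."""
    local_grads = []
    for x, y in batches:  # simulated ranks, one micro-batch each
        full_w = all_gather(param_shards)  # forward: materialize full weights
        pred = sum(w * xi for w, xi in zip(full_w, x))
        del full_w  # the gathered non-local shards are discarded after use
        # For a linear model the weight gradient is (pred - y) * x; a deep
        # network would all-gather the weights again here for backward.
        local_grads.append([(pred - y) * xi for xi in x])
    grad_shards = reduce_scatter(local_grads)
    # Each rank updates only the shard it owns.
    return [[w - lr * g for w, g in zip(ws, gs)]
            for ws, gs in zip(param_shards, grad_shards)]


shards = shard([0.0, 0.0, 0.0, 0.0], world_size=2)
new_shards = fsdp_step(shards, [([1.0, 0.0, 1.0, 0.0], 1.0),
                                ([0.0, 1.0, 0.0, 1.0], 2.0)], lr=0.5)
print(all_gather(new_shards))
```

Because every rank applies the same averaged gradient to its slice, gathering the updated shards gives the same weights plain data parallelism would, while each rank only ever stores its own shard of parameters, gradients and (in real training) optimizer states.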

Figure: FSDP Workflow (Source: link)



Hardware Used

Number of nodes: 2. Minimum required is 1.
Number of GPUs per node: 8
GPU type: A100
GPU memory: 80GB
Intra-node connection: NVLink
RAM per node: 1TB
CPU cores per node: 96
Inter-node connection: Elastic Fabric Adapter



Challenges with fine-tuning Llama 70B

We encountered three main challenges when trying to fine-tune Llama 70B with FSDP:

  1. FSDP wraps the model after loading the pre-trained model. If each process/rank within a node loads the Llama-70B model, it would require 70*4*8 GB ≈ 2TB of CPU RAM, where 4 is the number of bytes per parameter and 8 is the number of GPUs on each node. This would cause the CPU RAM to run out of memory, leading to processes being terminated.

  2. Saving full intermediate checkpoints using FULL_STATE_DICT with CPU offloading on rank 0 takes a lot of time and often results in NCCL Timeout errors due to indefinite hanging during broadcasting. However, at the end of training, we want the full model state dict instead of the sharded state dict, which is only compatible with FSDP.

  3. We need to improve the speed and reduce the VRAM usage to train faster and save compute costs.
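The CPU RAM estimate in the first challenge is simple arithmetic, sketched below under the same assumptions as the text (fp32 weights at 4 bytes per parameter, 8 ranks per node) plus one of our own: the node's 1TB of RAM is counted as 1000 GB. The helper cpu_ram_needed_gb is hypothetical; the rank-0-only case previews why loading on a single rank avoids the problem.

```python
def cpu_ram_needed_gb(params_billion: float, bytes_per_param: int,
                      loading_ranks: int) -> float:
    """CPU RAM (GB) needed if `loading_ranks` processes each hold a full copy."""
    return params_billion * bytes_per_param * loading_ranks


NODE_RAM_GB = 1000  # the 1TB of RAM per node listed under "Hardware Used"

every_rank = cpu_ram_needed_gb(70, 4, 8)  # all 8 ranks on a node load the model
rank0_only = cpu_ram_needed_gb(70, 4, 1)  # only rank 0 loads the weights

print(f"every rank loads: {every_rank:.0f} GB, fits: {every_rank <= NODE_RAM_GB}")
print(f"rank 0 only:      {rank0_only:.0f} GB, fits: {rank0_only <= NODE_RAM_GB}")
```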

Let's look at how to solve the above challenges and fine-tune a 70B model!

Before we start, here are all the resources required to reproduce our results:

  1. Codebase:
    https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/sft/training with flash-attn V2

  2. FSDP config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml

  3. SLURM script launch.slurm: https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25

  4. Model: meta-llama/Llama-2-70b-chat-hf

  5. Dataset: smangrul/code-chat-assistant-v1 (mixture of LIMA+GUANACO with proper formatting in a ready-to-train format)



Pre-requisites

First, follow these steps to install Flash Attention V2: Dao-AILab/flash-attention: Fast and memory-efficient exact attention (github.com). Install the latest nightlies of PyTorch with CUDA ≥11.8. Install the remaining requirements as per DHS-LLM-Workshop/code_assistant/training/requirements.txt. Here, we will be installing 🤗 Accelerate and 🤗 Transformers from the main branch.



Fine-Tuning



Addressing Challenge 1

PRs huggingface/transformers#25107 and huggingface/accelerate#1777 solve the first challenge and require no code changes on the user side. They do the following:

  1. Create the model with no weights on all ranks (using the meta device).
  2. Load the state dict only on rank==0 and set the model weights with that state dict on rank 0.
  3. For all other ranks, do torch.empty(*param.size(), dtype=dtype) for each parameter on the meta device.
  4. So, rank==0 will have loaded the model with the correct state dict, while all other ranks will have random weights.
  5. Set sync_module_states=True so that the FSDP object takes care of broadcasting them to all the ranks before training starts.
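The steps above can be sketched in a single process with plain PyTorch (2.x, which provides the torch.device context manager). This is not the Transformers/Accelerate implementation: build_model, init_on_rank and broadcast_from_rank0 are hypothetical names, the model is a tiny stand-in, Module.to_empty replaces the per-parameter torch.empty call, and the broadcast that FSDP performs with sync_module_states=True is imitated with load_state_dict because there is no process group here.

```python
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical tiny stand-in for a transformer; the real flow uses Llama-2-70B.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))


def init_on_rank(rank: int, pretrained_state: dict) -> nn.Module:
    # Step 1: build the module on the meta device, so no weight memory is allocated.
    with torch.device("meta"):
        model = build_model()
    # Step 3: give every parameter real but uninitialized CPU storage
    # (to_empty does for the whole module what torch.empty does per tensor).
    model = model.to_empty(device="cpu")
    if rank == 0:
        # Step 2: only rank 0 reads the checkpoint into its weights.
        model.load_state_dict(pretrained_state)
    # Step 4: rank 0 now holds the pretrained weights; other ranks hold garbage.
    return model


def broadcast_from_rank0(models: list) -> None:
    # Step 5 stand-in: FSDP(sync_module_states=True) broadcasts rank 0's
    # parameters to every rank; here we simply copy them in-process.
    src = models[0].state_dict()
    for model in models[1:]:
        model.load_state_dict(src)


torch.manual_seed(0)
pretrained = build_model().state_dict()  # pretend this came from disk
ranks = [init_on_rank(r, pretrained) for r in range(2)]
broadcast_from_rank0(ranks)
```

The key design point is that only one process ever pays the CPU memory cost of the checkpoint; every other rank allocates empty storage of the right shape and relies on the broadcast to fill it.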

Below is the output snippet on a 7B model on 2 GPUs, measuring the memory consumed and model parameters at various stages. We can observe that while loading the pre-trained model, rank 0 and rank 1 have a CPU total peak memory of 32744 MB and 1506 MB, respectively. Therefore, only rank 0 loads the pre-trained model, resulting in efficient usage of CPU RAM. The complete logs can be found here.

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744

accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
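The "end-begin" and "max-begin" figures in these logs are differences taken around the loading step. As a rough, standard-library-only illustration of that bookkeeping, the hypothetical context manager below records the same two deltas in MB using tracemalloc. This is an assumption-laden sketch, not the tool that produced the logs: tracemalloc only sees allocations made through Python's allocators, so it will not match process-level CPU memory for native libraries such as PyTorch.

```python
import tracemalloc
from contextlib import contextmanager

MB = 2 ** 20


@contextmanager
def track_cpu_memory(stats: dict):
    """Record allocation deltas (in MB) for the enclosed block into `stats`."""
    tracemalloc.start()
    begin, _ = tracemalloc.get_traced_memory()
    try:
        yield stats
    finally:
        end, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        stats["consumed"] = (end - begin) // MB        # end - begin
        stats["peak_consumed"] = (peak - begin) // MB  # max - begin


stats = {}
with track_cpu_memory(stats):
    buffer = bytes(50 * MB)  # stand-in for reading a checkpoint into memory
print(stats)
```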



Addressing Challenge 2

This is addressed by choosing the SHARDED_STATE_DICT state dict type when creating the FSDP config. SHARDED_STATE_DICT saves each GPU's shard separately, which makes it quick to save or resume training from an intermediate checkpoint. When FULL_STATE_DICT is used, the first process (rank 0) gathers the whole model on CPU and then saves it in a standard format.
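The difference between the two state dict types is mainly about who gathers what before writing. The toy sketch below contrasts them using JSON files; the helpers save_sharded, load_sharded and save_full are hypothetical, ranks are simulated one after another, and real FSDP checkpoints are written by PyTorch/Accelerate, not by anything like this.

```python
import json
import os
import tempfile
from typing import Dict, List

Shard = Dict[str, List[float]]


def save_sharded(rank_shards: List[Shard], ckpt_dir: str) -> None:
    # SHARDED_STATE_DICT-style: each rank writes only its own shard and no
    # gather is needed (in real training the ranks write in parallel).
    for rank, shard in enumerate(rank_shards):
        with open(os.path.join(ckpt_dir, f"rank{rank}.json"), "w") as f:
            json.dump(shard, f)


def load_sharded(ckpt_dir: str, rank: int) -> Shard:
    # On resume, each rank reads back just its own file.
    with open(os.path.join(ckpt_dir, f"rank{rank}.json")) as f:
        return json.load(f)


def save_full(rank_shards: List[Shard], path: str) -> None:
    # FULL_STATE_DICT-style: rank 0 gathers every shard into one complete
    # state dict in CPU memory, then writes a single standard file.
    full: Shard = {}
    for shard in rank_shards:
        for name, values in shard.items():
            full.setdefault(name, []).extend(values)
    with open(path, "w") as f:
        json.dump(full, f)


shards = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
with tempfile.TemporaryDirectory() as tmp:
    save_sharded(shards, tmp)
    resumed_rank1 = load_sharded(tmp, rank=1)
    full_path = os.path.join(tmp, "full.json")
    save_full(shards, full_path)
    with open(full_path) as f:
        full_ckpt = json.load(f)
```

The full save funnels the entire model through one process, which is what makes it slow and prone to timeouts at 70B scale; the sharded save keeps work and memory spread across ranks.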

Let's create the Accelerate config via the command below:

accelerate config --config_file "fsdp_config.yaml"

Figure: fsdp_config

The resulting config is available here: fsdp_config.yaml. Here, the sharding strategy is FULL_SHARD. We're using TRANSFORMER_BASED_WRAP for the auto wrap policy, and it uses _no_split_modules to find the Transformer block name for nested FSDP auto wrap. We use SHARDED_STATE_DICT to save the intermediate checkpoints and optimizer states in this format, as recommended by the PyTorch team. Make sure to enable broadcasting module parameters from rank 0 at the start, as mentioned in the above paragraph on addressing Challenge 1. We're enabling bf16 mixed precision training.
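As a quick sanity check, the settings described in the previous paragraph can be mirrored in a Python dict and compared against what this post relies on. The dict below is a hypothetical, partial mirror of fsdp_config.yaml (the linked file has more entries), and check_fsdp_config is an illustrative helper, not part of Accelerate. The key names follow Accelerate's config file; note that older Accelerate versions store the sharding strategy as the integer 1 rather than the string "FULL_SHARD".

```python
# Hypothetical, partial mirror of the FSDP-related entries in fsdp_config.yaml.
FSDP_CONFIG = {
    "distributed_type": "FSDP",
    "mixed_precision": "bf16",
    "fsdp_config": {
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_state_dict_type": "SHARDED_STATE_DICT",
        "fsdp_sync_module_states": True,
    },
}

# Settings this post relies on, and why.
EXPECTED = {
    "fsdp_sharding_strategy": ("FULL_SHARD", "shard params, grads and optimizer states"),
    "fsdp_auto_wrap_policy": ("TRANSFORMER_BASED_WRAP", "wrap each Transformer block"),
    "fsdp_state_dict_type": ("SHARDED_STATE_DICT", "fast intermediate checkpoints (Challenge 2)"),
    "fsdp_sync_module_states": (True, "broadcast rank 0 weights (Challenge 1)"),
}


def check_fsdp_config(cfg: dict) -> list:
    """Return human-readable problems; an empty list means the config matches."""
    problems = []
    if cfg.get("mixed_precision") != "bf16":
        problems.append("mixed_precision should be 'bf16'")
    fsdp = cfg.get("fsdp_config", {})
    for key, (value, reason) in EXPECTED.items():
        if fsdp.get(key) != value:
            problems.append(f"{key} should be {value!r}: {reason}")
    return problems


print(check_fsdp_config(FSDP_CONFIG))
```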

To save the final checkpoint as the full model state dict, the code snippet below is used:

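# Switch FSDP from SHARDED_STATE_DICT to FULL_STATE_DICT just before the final save,
# so that save_model writes a complete, standard model state dict.
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)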



Addressing Challenge 3

Flash Attention and gradient checkpointing are required for faster training and reduced VRAM usage, which make fine-tuning possible and save compute costs. The codebase currently uses monkey patching, and the implementation is at chat_assistant/training/llama_flash_attn_monkey_patch.py.
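Gradient checkpointing trades compute for memory: only each block's input is kept during the forward pass, and the block's internal activations are recomputed during backward. The sketch below shows the mechanism with torch.utils.checkpoint.checkpoint on a hypothetical toy model (Block and Stack are illustrative names, not the Llama modules). In the training script itself, checkpointing is switched on through the --gradient_checkpointing True flag rather than hand-written code like this.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical residual MLP block standing in for a decoder layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class Stack(nn.Module):
    def __init__(self, dim: int, depth: int, use_checkpointing: bool):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Keep only the block input; internal activations are recomputed
                # in backward. use_reentrant=False matches the --use_reentrant
                # False flag in this post's launch command.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Because the recomputation is exact, a checkpointed model and a non-checkpointed copy with the same weights produce the same gradients; only peak activation memory and step time differ.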

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness introduces a way to compute exact attention while being faster and more memory-efficient, by leveraging knowledge of the memory hierarchy of the underlying hardware/GPUs: the higher the bandwidth/speed of a memory, the smaller its capacity, since it becomes more expensive.

If we follow the blog Making Deep Learning Go Brrrr From First Principles, we can figure out that the Attention module on current hardware is memory-bound/bandwidth-bound. The reason is that Attention mostly consists of elementwise operations, as shown below on the left-hand side. We can observe that masking, softmax and dropout operations take up the bulk of the time, instead of the matrix multiplications, which make up the bulk of the FLOPs.

Figure: Attention Bottlenecks (Source: link)
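A back-of-the-envelope roofline calculation shows why elementwise ops are memory-bound. The sketch compares arithmetic intensity (FLOPs per byte moved to and from HBM) of one elementwise op on the N x N score matrix with the Q @ K^T matmul. The numbers are assumptions for illustration: bf16 tensors (2 bytes), roughly 1 FLOP per element for the elementwise op, N = 2048 (the max_seq_len used in this post), head dimension 128 (as in Llama 2 70B), and approximate A100 80GB peaks of about 312 TFLOP/s dense bf16 and about 2 TB/s HBM bandwidth.

```python
def elementwise_intensity(n: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte of one elementwise op (e.g. masking) on an n x n matrix.

    Assumes ~1 FLOP per element and one full read plus one full write.
    """
    flops = n * n
    bytes_moved = 2 * n * n * bytes_per_el
    return flops / bytes_moved


def matmul_intensity(n: int, d: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte of S = Q @ K^T with Q, K of shape (n, d) and S of shape (n, n)."""
    flops = 2 * n * n * d
    bytes_moved = (2 * n * d + n * n) * bytes_per_el  # read Q and K, write S
    return flops / bytes_moved


# Approximate A100 80GB figures (assumption): FLOPs/byte at which compute,
# rather than memory bandwidth, becomes the limit.
A100_RIDGE = 312e12 / 2.0e12

n, d = 2048, 128
print(f"elementwise: {elementwise_intensity(n):.2f} FLOP/B")
print(f"QK^T matmul: {matmul_intensity(n, d):.1f} FLOP/B")
print(f"A100 ridge:  {A100_RIDGE:.0f} FLOP/B")
```

Under these assumptions the elementwise op sits hundreds of times below the ridge point, so the GPU spends its time waiting on HBM, while the matmul is within the same order of magnitude as the ridge.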

This is precisely the problem that Flash Attention addresses. The idea is to remove redundant HBM reads/writes. It does so by keeping everything in SRAM, performing all the intermediate steps, and only then writing the final result back to HBM, which is also known as Kernel Fusion. Below is an illustration of how this overcomes the memory-bound bottleneck.

Figure: kernel_fusion (Source: link)
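The effect of kernel fusion on memory traffic can be shown with a toy model. In the sketch below, HBM is a hypothetical class that counts element reads and writes; unfused runs scale, mask and exp as three separate passes that each round-trip through "HBM", while fused does all three in one pass with intermediates held in local variables (standing in for SRAM/registers). This only models traffic; real Flash Attention fuses these steps inside a CUDA kernel.

```python
import math
from typing import List


class HBM:
    """Toy off-chip memory that counts how many elements are read and written."""

    def __init__(self, data: List[float]):
        self.data = list(data)
        self.reads = 0
        self.writes = 0

    def load(self) -> List[float]:
        self.reads += len(self.data)
        return list(self.data)

    def store(self, values: List[float]) -> None:
        self.writes += len(values)
        self.data = list(values)


def unfused(hbm: HBM, scale: float, mask: List[bool]) -> None:
    # Three separate "kernels": each one round-trips the whole array through HBM.
    x = hbm.load()
    hbm.store([v * scale for v in x])
    x = hbm.load()
    hbm.store([v if keep else -math.inf for v, keep in zip(x, mask)])
    x = hbm.load()
    hbm.store([math.exp(v) for v in x])


def fused(hbm: HBM, scale: float, mask: List[bool]) -> None:
    # One "kernel": a single read, intermediates stay in local values,
    # and only the final result is written back.
    x = hbm.load()
    hbm.store([math.exp(v * scale) if keep else 0.0 for v, keep in zip(x, mask)])


a, b = HBM([1.0, 2.0, 3.0]), HBM([1.0, 2.0, 3.0])
unfused(a, 0.5, [True, False, True])
fused(b, 0.5, [True, False, True])
print(a.data == b.data, (a.reads, a.writes), (b.reads, b.writes))
```

Both versions compute identical results, but the fused one moves a third of the data, which is exactly where a memory-bound operation gains its speed.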

Tiling is used during the forward and backward passes to chunk the NxN softmax/scores computation into blocks, to overcome the limited size of SRAM. To enable tiling, the online softmax algorithm is used. Recomputation is used during the backward pass in order to avoid storing the entire NxN softmax/score matrix during the forward pass. This greatly reduces memory consumption.
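The online softmax trick that makes tiling possible can be sketched for a single query row in pure Python. tiled_attention walks over K/V one block at a time, keeping only a running max, a running sum of exponentials and an unnormalized output; whenever the max increases, earlier partial results are rescaled by exp(old_max - new_max). Both function names are hypothetical, and this CPU sketch covers only the forward pass for one query; FlashAttention applies the same recurrence to blocks of queries inside GPU SRAM and recomputes scores in the backward pass.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def _dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vec, K: List[Vec], V: List[Vec]) -> List[float]:
    """Reference: materialize all scores for the query, softmax, weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [_dot(q, k) * scale for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / z for j in range(len(V[0]))]


def tiled_attention(q: Vec, K: List[Vec], V: List[Vec], block_size: int) -> List[float]:
    """Online-softmax attention: one K/V block at a time, never all scores at once."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(V[0])  # unnormalized output
    for start in range(0, len(K), block_size):
        k_blk, v_blk = K[start:start + block_size], V[start:start + block_size]
        scores = [_dot(q, k) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) = 0 on the first block).
        correction = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(probs)
        acc = [a * correction + sum(p * v[j] for p, v in zip(probs, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Because the rescaling is exact, the tiled result matches the naive one up to floating-point rounding for any block size, while the naive version must hold every score at once.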

For a simplified and in-depth understanding of Flash Attention, please refer to the blog posts ELI5: FlashAttention and Making Deep Learning Go Brrrr From First Principles, along with the original paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.



Bringing it all-together

To run the training using the Accelerate launcher with SLURM, refer to this gist: launch.slurm. Below is an equivalent command showcasing how to use the Accelerate launcher to run the training. Notice that we are overriding the main_process_ip, main_process_port, machine_rank, num_processes and num_machines values of the fsdp_config.yaml. Another important point to note here is that the storage is shared between all the nodes.

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --seed 100 \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --chat_template_format "none" \
    --add_special_tokens False \
    --append_concat_token False \
    --splits "train,test" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --log_level "info" \
    --eval_steps 100 \
    --save_steps 250 \
    --logging_strategy "steps" \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --push_to_hub \
    --hub_private_repo True \
    --hub_strategy "every_save" \
    --bf16 True \
    --packing True \
    --learning_rate 5e-5 \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing True \
    --use_reentrant False \
    --dataset_text_field "content" \
    --use_flash_attn True \
    --ddp_timeout 5400 \
    --optim paged_adamw_32bit
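A few quantities follow directly from the flags in this command, sketched below with hypothetical helpers. Two assumptions are labeled in the code: ranks are numbered node-major (node 0 holds ranks 0 to 7, node 1 holds 8 to 15), and with --packing True every training sample is a full max_seq_len chunk, so the token count is an upper bound rather than a measured figure.

```python
def procs_per_node(num_processes: int, num_machines: int) -> int:
    assert num_processes % num_machines == 0, "expect the same GPU count on every node"
    return num_processes // num_machines


def global_rank(machine_rank: int, local_rank: int, per_node: int) -> int:
    # Assumption: node-major rank numbering across machines.
    return machine_rank * per_node + local_rank


def tokens_per_step(num_processes: int, per_device_bs: int, grad_accum: int,
                    max_seq_len: int) -> int:
    # Assumption: with packing, every sample is exactly max_seq_len tokens.
    return num_processes * per_device_bs * grad_accum * max_seq_len


per_node = procs_per_node(num_processes=16, num_machines=2)
step_tokens = tokens_per_step(16, per_device_bs=1, grad_accum=1, max_seq_len=2048)
total_tokens = step_tokens * 500  # --max_steps 500
print(per_node, step_tokens, total_tokens)
```

So each optimizer step sees 16 packed sequences (one per GPU), and 500 steps cover at most about 16.4M training tokens.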

Fine-tuning completed in ~13.5 hours, and below is the training loss plot.

Figure: train_loss

Important Note:
Here, if you are using only a single node with 8 A100 80GB GPUs, then the paged_adamw_32bit optimizer from bitsandbytes is required. For more details about paged optimizers, please refer to https://huggingface.co/docs/bitsandbytes/main/en/optimizers#paged-optimizers.
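A rough estimate helps show why the single-node case is so tight. The byte counts below are our assumptions, not measurements from this run: FSDP keeps the sharded master parameters and gradients in fp32 (bf16 is used for compute), and 32-bit AdamW keeps two fp32 states per parameter, for 16 bytes per parameter in total, ignoring activations, buffers and temporarily all-gathered layers.

```python
# Assumed bytes per parameter for full fine-tuning with 32-bit AdamW.
BYTES_PER_PARAM = {
    "fp32 sharded parameters": 4,
    "fp32 gradients": 4,
    "Adam first moment (fp32)": 4,
    "Adam second moment (fp32)": 4,
}


def per_gpu_gb(n_params: int, num_gpus: int) -> float:
    """Fully sharded training state per GPU in GB (activations not included)."""
    return n_params * sum(BYTES_PER_PARAM.values()) / num_gpus / 1e9


N_PARAMS = 70_000_000_000
one_node = per_gpu_gb(N_PARAMS, 8)    # 1 node x 8 GPUs
two_nodes = per_gpu_gb(N_PARAMS, 16)  # 2 nodes x 8 GPUs
print(f"1 node: {one_node:.0f} GB/GPU, 2 nodes: {two_nodes:.0f} GB/GPU, budget: 80 GB")
```

Under these assumptions the sharded state alone exceeds an 80GB GPU on a single node, which is consistent with needing an optimizer that can page its states out to CPU memory.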

Below is an example conversation using the above model:

System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully
as possible, while being safe. Your answers should not include any harmful,
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that
your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why
instead of answering something not correct. If you don't know the answer to a
question, please don't share false information.

- Human: What's Deep Learning? Explain like a Pirate.

+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'?
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data.
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information.
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data.
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own.
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered.
+ Fair winds and following seas!

- Human: Now explain it like a chef.

+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:

+ Think of a recipe as a series of steps used to transform raw ingredients into a delicious dish.
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction.
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome.
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions.
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect mix of accuracy and efficiency.
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!

The entire conversation is formatted as below:

<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...
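A small helper that renders a conversation in this format might look like the sketch below. format_conversation is a hypothetical name, and the single-space separators are an assumption read off the template line; the exact whitespace in the smangrul/code-chat-assistant-v1 dataset may differ.

```python
from typing import List, Tuple


def format_conversation(system: str, turns: List[Tuple[str, str]]) -> str:
    """Render a system message plus (question, answer) turns in the template above."""
    parts = [f"<|system|> {system} <|endoftext|>"]
    for question, answer in turns:
        parts.append(f"<|prompter|> {question} <|endoftext|>")
        parts.append(f"<|assistant|> {answer} <|endoftext|>")
    return " ".join(parts)


print(format_conversation("system message", [("Q1", "A1"), ("Q2", "A2")]))
```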



Conclusion

We successfully fine-tuned a 70B Llama model using PyTorch FSDP in a multi-node multi-GPU setting while addressing various challenges. We saw how 🤗 Transformers and 🤗 Accelerate now support an efficient way of initializing large models when using FSDP, to avoid running out of CPU RAM. This was followed by recommended practices for saving/loading intermediate checkpoints and for saving the final model in a way that makes it readily usable. To enable faster training and reduce GPU memory usage, we outlined the importance of Flash Attention and Gradient Checkpointing. Overall, we can see how a simple config using 🤗 Accelerate enables fine-tuning of such large models in a multi-node multi-GPU setting.


