Training with packed instruction tuning examples (without padding) is now compatible with Flash Attention 2 in Hugging Face, thanks to a recent PR and the new DataCollatorWithFlattening.
It can provide up to a 2x improvement in training throughput while maintaining convergence quality. Read on for the details!
Introduction
Padding input sequences in mini-batches is a common way to collate inputs during training. However, this introduces inefficiencies because of the irrelevant padding tokens. Packing examples without padding, and using the token position information, is a more efficient alternative. However, previous implementations of packing did not consider example boundaries when using Flash Attention 2, resulting in undesired cross-example attention that reduces quality and convergence.
Hugging Face Transformers now addresses this with a new feature that maintains boundary awareness during packing, alongside the introduction of a new data collator, DataCollatorWithFlattening.
By selecting DataCollatorWithFlattening, Hugging Face Trainer users can now seamlessly concatenate sequences into a single tensor while accounting for sequence boundaries during Flash Attention 2 computations. This is achieved through flash_attn_varlen_func, which uses the cumulative sequence lengths in each mini-batch (cu_seqlens).
The same feature is available to Hugging Face SFTTrainer users in the TRL library by setting a new flag, padding_free=True, when calling the data collator DataCollatorForCompletionOnlyLM.
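To make the role of cu_seqlens concrete, here is a minimal pure-Python sketch that derives cumulative sequence lengths from a flattened row of position_ids, assuming every example's positions restart at 0. The helper `cu_seqlens_from_position_ids` is hypothetical; Transformers derives cu_seqlens from position_ids in a similar spirit before calling flash_attn_varlen_func, but this is not its code.

```python
from typing import List


def cu_seqlens_from_position_ids(position_ids: List[int]) -> List[int]:
    """Hypothetical helper: derive cumulative sequence lengths (cu_seqlens)
    from a flattened row of position_ids.

    Assumes each packed example starts at position 0, so every 0 marks
    an example boundary.
    """
    if not position_ids or position_ids[0] != 0:
        raise ValueError("the flattened row must start with position 0")
    # Offsets at which a new example begins.
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    # cu_seqlens = start offset of every example, plus the total length.
    return starts + [len(position_ids)]


# Three examples of lengths 3, 5 and 2 flattened into one row.
position_ids = [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
cu_seqlens = cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens)  # [0, 3, 8, 10]
# Example i spans tokens cu_seqlens[i] : cu_seqlens[i + 1].
print([cu_seqlens[i + 1] - cu_seqlens[i] for i in range(len(cu_seqlens) - 1)])  # [3, 5, 2]
```

The variable-length kernel only needs these offsets to keep each example's attention within its own span, which is why no padding tokens are required.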
Up to 2x throughput increase
We see a significant improvement in training throughput using this feature with the new DataCollatorWithFlattening. The figure below shows the throughput measured in tokens/second during training. In this example, the throughput is the per-GPU average over 8 A100-80GB GPUs over one epoch of a 20K randomly selected sample from two different instruction tuning datasets, FLAN and OrcaMath.
FLAN has short sequences on average but a large variance in sequence length, so example lengths in each batch may vary widely. This means that padded FLAN batches may incur a significant overhead in unused padding tokens. Training on the FLAN dataset shows a significant benefit from the new DataCollatorWithFlattening in terms of increased throughput. We see a 2x throughput increase on the models shown here: llama2-7B, mistral-7B, and granite-8B-code.
OrcaMath has longer examples and a lower variance in example length. As such, the improvement from packing is lower. Our experiments show a 1.4x increase in throughput when training with this form of packing on the OrcaMath dataset across these three models.
Memory usage also improves through packing with the new DataCollatorWithFlattening. The following figure shows the peak memory usage of the same three models training on the same two datasets. Peak memory is reduced by 20% on the FLAN dataset, which benefits considerably from packing.
Peak memory reduction is 6% on the OrcaMath dataset, with its more homogeneous example lengths.
Packing examples, when it reduces the number of optimization steps, may harm training convergence. The new feature, however, retains the mini-batches and, hence, the same number of optimization steps as would be used with padded examples. Thus, there is no impact on training convergence, as we see in the next figure, which shows similar validation loss for the same three models training on the same two datasets, whether the models are trained with packing using the new DataCollatorWithFlattening or with padding.
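The effect of length variance on padding overhead can be illustrated with simple token accounting. The sketch below uses made-up lengths, not statistics from FLAN or OrcaMath, and the helper `padding_stats` is hypothetical. It only counts tokens; real throughput also depends on attention cost, kernel efficiency, and hardware, so the padding fraction is not a direct prediction of speedup.

```python
from typing import Dict, List


def padding_stats(lengths: List[int]) -> Dict[str, float]:
    """Token accounting for one mini-batch padded to its longest example,
    compared with the same examples flattened without padding."""
    real_tokens = sum(lengths)
    padded_tokens = len(lengths) * max(lengths)  # batch size x longest sequence
    return {
        "real_tokens": real_tokens,
        "padded_tokens": padded_tokens,
        "padding_fraction": 1 - real_tokens / padded_tokens,
    }


# Hypothetical batches: widely varying lengths (FLAN-like)
# versus similar lengths (OrcaMath-like).
high_variance = padding_stats([20, 300, 45, 120])
low_variance = padding_stats([280, 300, 260, 290])
print(high_variance)  # 485 real tokens out of 1200 processed (~60% padding)
print(low_variance)   # 1130 real tokens out of 1200 processed (~6% padding)
```

With flattening, only the real tokens are processed in both cases, so the batch with widely varying lengths has much more to gain.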
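A quick count shows why flattening leaves the optimization schedule intact. In the hypothetical sketch below, `steps_per_epoch` compares padding, flattening, and, for contrast, fixed-length packing, where all examples are concatenated and re-chunked into blocks of `block_len` tokens. Treating fixed-length packing as the kind of packing that reduces step counts is an assumption, and the numbers are illustrative, not taken from the experiments above.

```python
import math
from typing import Dict, List


def steps_per_epoch(lengths: List[int], batch_size: int, block_len: int) -> Dict[str, int]:
    """Optimizer steps per epoch under three collation strategies.

    Assumes one optimizer step per mini-batch (no gradient accumulation).
    """
    n_examples = len(lengths)
    # Padding: each mini-batch holds batch_size examples.
    padded = math.ceil(n_examples / batch_size)
    # Flattening: the same mini-batches, each concatenated into one row.
    flattened = math.ceil(n_examples / batch_size)
    # Fixed-length packing (for contrast): tokens are re-chunked into blocks
    # of block_len tokens, and each mini-batch holds batch_size blocks.
    n_blocks = math.ceil(sum(lengths) / block_len)
    fixed_length_packed = math.ceil(n_blocks / batch_size)
    return {
        "padded": padded,
        "flattened": flattened,
        "fixed_length_packed": fixed_length_packed,
    }


# 20K hypothetical examples of 100 tokens each, batch size 8.
print(steps_per_epoch([100] * 20_000, batch_size=8, block_len=2048))
# {'padded': 2500, 'flattened': 2500, 'fixed_length_packed': 123}
```

Fixed-length packing changes both the number of steps and how many examples contribute to each gradient update, whereas flattening keeps both identical to the padded baseline.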
How it works
Consider a batch of data with a batch size of 4, where the four sequences are as follows:
After concatenating the examples, the padding-free collator returns the input_ids, labels, and position_ids of each example. Hence, for this batch of data, the collator provides:
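The flattening step can be sketched in a few lines of plain Python. This is a simplified, hypothetical re-implementation (`flatten_batch` and the token IDs are made up for illustration), not the library code. To our understanding, DataCollatorWithFlattening behaves similarly and by default places -100 at the first label of each example; treat that detail as an assumption and check the Transformers source for your version.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def flatten_batch(features: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of padding-free collation.

    Concatenates every example into a single row, restarts position_ids at 0
    for each example, and masks the first label of each example so that no
    token is trained to predict across an example boundary.
    """
    input_ids, labels, position_ids = [], [], []
    for feature in features:
        ids = feature["input_ids"]
        # Use the provided labels if present, otherwise the inputs themselves.
        lab = feature.get("labels", ids)
        input_ids += ids
        labels += [IGNORE_INDEX] + lab[1:]
        position_ids += list(range(len(ids)))
    # Batch dimension of 1: the whole mini-batch becomes one flattened sequence.
    return {"input_ids": [input_ids], "labels": [labels], "position_ids": [position_ids]}


# Four hypothetical sequences of lengths 3, 2, 4 and 1.
batch = [
    {"input_ids": [11, 12, 13]},
    {"input_ids": [21, 22]},
    {"input_ids": [31, 32, 33, 34]},
    {"input_ids": [41]},
]
out = flatten_batch(batch)
print(out["input_ids"])     # [[11, 12, 13, 21, 22, 31, 32, 33, 34, 41]]
print(out["labels"])        # [[-100, 12, 13, -100, 22, -100, 32, 33, 34, -100]]
print(out["position_ids"])  # [[0, 1, 2, 0, 1, 0, 1, 2, 3, 0]]
```

Each 0 in position_ids marks where a new example starts, which is the only boundary information the attention kernel needs.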
The modifications required are lightweight and are limited to providing the position_ids to Flash Attention 2.
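Conceptually, providing position_ids restricts attention to a block-diagonal causal pattern: each token attends only to earlier tokens of its own example. The variable-length Flash Attention kernel never builds a dense mask, so the sketch below is only a reference for which token pairs interact, not how the kernel works. The helpers `segment_ids` and `packed_causal_mask` are hypothetical and assume positions restart at 0 for each example.

```python
from typing import List


def segment_ids(position_ids: List[int]) -> List[int]:
    """Label each token with the index of the example it belongs to,
    assuming every example's positions start at 0."""
    seg, current = [], -1
    for pos in position_ids:
        if pos == 0:
            current += 1
        seg.append(current)
    return seg


def packed_causal_mask(position_ids: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when query token q may attend to key token k:
    k must not lie in the future (causal) and must belong to the same example."""
    seg = segment_ids(position_ids)
    n = len(position_ids)
    return [[k <= q and seg[k] == seg[q] for k in range(n)] for q in range(n)]


# Two packed examples of lengths 2 and 3.
mask = packed_causal_mask([0, 1, 0, 1, 2])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xx...
# ..x..
# ..xx.
# ..xxx
```

Without the boundary information, the lower-left block would also be filled in, which is exactly the cross-example attention that earlier packing implementations allowed.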
This relies, however, on the model exposing position_ids. As of the time of writing, 14 models expose them and are supported by the solution. Specifically, Llama 2 and 3, Mistral, Mixtral, Granite, DBRX, Falcon, Gemma, OLMo, Phi 1, 2, and 3, phi3, Qwen 2 and Qwen 2 MoE, StableLM, and StarCoder 2 are all supported by the solution.
Getting started
Reaping the benefits of packing with position_ids is easy.
If you are using Hugging Face Trainer from Transformers, only two steps are required:
- Instantiate the model with Flash Attention 2
- Use the new DataCollatorWithFlattening
If you are using Hugging Face SFTTrainer from TRL with DataCollatorForCompletionOnlyLM, then the two required steps are:
- Instantiate the model with Flash Attention 2
- Set padding_free=True when calling DataCollatorForCompletionOnlyLM as follows:

```python
collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, padding_free=True)
```
How to use it
For Trainer users, the example below illustrates how to use the new feature.
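```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    DataCollatorWithFlattening,
    Trainer,
    TrainingArguments,
)

# Step 1: instantiate the model with Flash Attention 2
# (requires the flash-attn package and a supported GPU).
model = AutoModelForCausalLM.from_pretrained(
    "instructlab/merlinite-7b-lab",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# The dataset is expected to be already tokenized, i.e. each example
# provides "input_ids" (and optionally "labels").
train_dataset = load_dataset("json", data_files="path/to/my/dataset")["train"]

# Step 2: use the new collator, which flattens each mini-batch into one
# sequence and returns position_ids that mark the example boundaries.
data_collator = DataCollatorWithFlattening()

train_args = TrainingArguments(output_dir="/save/path")
trainer = Trainer(
    args=train_args,
    model=model,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()
```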
For TRL users, the example below shows how to use the new feature with SFTTrainer.
```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

# Step 1: instantiate the model with Flash Attention 2.
model = AutoModelForCausalLM.from_pretrained(
    "instructlab/merlinite-7b-lab",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("instructlab/merlinite-7b-lab")
tokenizer.pad_token = tokenizer.eos_token


def formatting_prompts_func(example):
    # Receives a batch of examples and returns one formatted string per example.
    output_texts = []
    for i in range(len(example["instruction"])):
        text = f"### Query: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts


response_template = " ### Answer:"
# Encode the template without special tokens and drop the first two IDs so the
# result matches how the template is tokenized mid-sequence (tokenizer-specific).
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]

# Step 2: enable padding-free packing in the completion-only collator.
collator = DataCollatorForCompletionOnlyLM(
    response_template_ids, tokenizer=tokenizer, padding_free=True
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./tmp",
        gradient_checkpointing=True,
        per_device_train_batch_size=8,
    ),
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)
trainer.train()
```
Conclusions
Packing instruction tuning examples, instead of padding, is now fully compatible with Flash Attention 2, thanks to a recent PR and the new DataCollatorWithFlattening. The method is compatible with models that use position_ids. Benefits can be seen in throughput and peak memory usage during training, with no degradation in training convergence. Actual throughput and memory improvements depend on the model and the distribution of example lengths in the training data. Training with data that has a wide variation in example lengths will see the greatest benefit, relative to padding, from using DataCollatorWithFlattening. The same feature is available to SFTTrainer users in the TRL library by setting a new flag, padding_free=True, when calling DataCollatorForCompletionOnlyLM.
For a more detailed analysis, have a look at the paper at https://huggingface.co/papers/2407.09105





