Reinforcement Learning from Human Feedback (RLHF) has become the de facto last training step of LLMs such as GPT-4 or Claude to ensure that the language model's outputs are aligned with human expectations such as chattiness or safety features. However, it brings some of the complexity of RL into NLP: we need to build a good reward function, train the model to estimate the value of a state, and at the same time be careful not to stray too far from the original model and produce gibberish instead of sensible text. Such a process is quite involved, requiring a number of complex moving parts where it is not always easy to get things right.
The recent paper Direct Preference Optimization by Rafailov, Sharma, Mitchell et al. proposes to cast the RL-based objective used by existing methods into an objective which can be directly optimized via a simple binary cross-entropy loss, which greatly simplifies this process of refining LLMs.
This blog post introduces the Direct Preference Optimization (DPO) method which is now available in the TRL library and shows how one can fine-tune the recent Llama v2 7B-parameter model on the stack-exchange preference dataset, which contains ranked answers to questions on the various Stack Exchange portals.
DPO vs PPO
In the traditional model of optimising human-derived preferences via RL, the go-to method has been to use an auxiliary reward model and fine-tune the model of interest so that it maximizes this given reward via the machinery of RL. Intuitively, we use the reward model to provide feedback to the model we are optimising so that it generates high-reward samples more often and low-reward samples less often. At the same time we use a frozen reference model to make sure that whatever is generated does not deviate too much and continues to maintain generation diversity. This is typically done by adding a KL penalty to the full reward maximisation objective via a reference model, which serves to prevent the model from learning to cheat or exploit the reward model.
The DPO formulation bypasses the reward modeling step and directly optimises the language model on preference data via a key insight: namely an analytical mapping from the reward function to the optimal RL policy, which enables the authors to transform the RL loss over the reward and reference models into a loss over the reference model directly! This mapping intuitively measures how well a given reward function aligns with the given preference data. DPO thus starts with the optimal solution to the RLHF loss and via a change of variables derives a loss over only the reference model!
Thus this direct likelihood objective can be optimized without the need for a reward model or the need to perform the potentially fiddly RL-based optimisation.
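Concretely, the resulting objective is just a binary cross-entropy on the difference of implicit rewards. Here is a minimal sketch in PyTorch (illustrative variable names, not TRL's internal implementation):

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: how much more probable the policy makes a response
    # than the frozen reference model, scaled by beta.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Binary cross-entropy on the margin: push the chosen response to have
    # a higher implicit reward than the rejected one.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()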
How to train with TRL
As mentioned, typically the RLHF pipeline consists of these distinct parts:
- a supervised fine-tuning (SFT) step
- the strategy of annotating data with preference labels
- training a reward model on the preference data
- and the RL optimization step
The TRL library comes with helpers for all these parts; however, DPO training does away with the task of reward modeling and RL (steps 3 and 4) and directly optimizes the DPO objective on preference-annotated data.
In this respect we would still need to do step 1, but instead of steps 3 and 4 we need to provide the DPOTrainer in TRL with preference data from step 2, which has a very specific format, namely a dictionary with the following three keys:
- prompt: the context prompt which is given to a model at inference time for text generation
- chosen: contains the preferred generated response to the corresponding prompt
- rejected: contains the response which is not preferred or should not be the sampled response with respect to the given prompt
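For illustration, a single processed entry in this format might look like the following (the question and answers here are made up):

sample = {
    "prompt": "Question: How do I sort a list in Python?\n\nAnswer: ",
    "chosen": "Use the built-in sorted() function or list.sort().",
    "rejected": "You have to write your own bubble sort.",
}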
As an example, for the stack-exchange preference pairs dataset, we can map the dataset entries to return the desired dictionary via the following helper and drop all the original columns:
from typing import Dict, List

from datasets import load_dataset

def return_prompt_and_responses(samples) -> Dict[str, List[str]]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],
        "rejected": samples["response_k"],
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl",
)
original_columns = dataset.column_names

dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns,
)
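A quick sanity check (purely illustrative) confirms that only the three expected keys remain after the mapping:

print(dataset.column_names)       # expect: ['prompt', 'chosen', 'rejected']
print(dataset[0]["prompt"][:50])  # start of the first formatted prompt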
Once we have the dataset sorted, the DPO loss is essentially a supervised loss which obtains an implicit reward via a reference model, and thus at a high level the DPOTrainer requires the base model we wish to optimize as well as a reference model:
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    beta=0.1,
    train_dataset=dataset,
    tokenizer=tokenizer,
    args=training_args,
)
where the beta hyper-parameter is the temperature parameter for the DPO loss, typically in the range 0.1 to 0.5. This controls how much we pay attention to the reference model, in the sense that as beta gets smaller the more we ignore the reference model. Once we have our trainer initialised we can then train it on the dataset with the given training_args by simply calling:
dpo_trainer.train()
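The training_args above is a standard transformers TrainingArguments object; a minimal example configuration (the specific values here are illustrative, not the ones used in our experiments) could look like this:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./dpo_output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    max_steps=1000,
    logging_steps=10,
    bf16=True,
    # keep the prompt/chosen/rejected columns for the DPO data collator
    remove_unused_columns=False,
)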
Experiment with Llama v2
The benefit of implementing the DPO trainer in TRL is that one can take advantage of all the extra bells and whistles of training large LLMs which come with TRL and its dependent libraries like PEFT and Accelerate. With these libraries we are even able to train a Llama v2 model using the QLoRA technique provided by the bitsandbytes library.
Supervised Fine-Tuning
The process as introduced above involves the supervised fine-tuning step using QLoRA on the 7B Llama v2 model on the SFT split of the data via TRL's SFTTrainer:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer

# Load the base model in 4-bit so it fits on a single GPU (QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# LoRA adapters applied to the attention projections.
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
)
trainer.train()
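At the end of SFT it is worth persisting the adapter weights so they can be reused in the next step; a minimal way to do this (the output path is an assumption) is:

trainer.save_model(script_args.output_dir)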
DPO Training
Once SFT has finished, we can save the resulting model and move on to the DPO training. As is typically done, we will utilize the saved model from the previous SFT step for both the base model as well as the reference model of DPO. Then we can use these to train the model with the DPO objective on the stack-exchange preference data shown above. Since the models were trained via LoRA adapters, we load the models via PEFT's AutoPeftModelForCausalLM helpers:
import torch
from peft import AutoPeftModelForCausalLM
from trl import DPOTrainer

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()
As can be seen, we load the model in the 4-bit configuration and then train it via the QLoRA method via the peft_config arguments. The trainer will also evaluate the progress during training with respect to the evaluation dataset and report back a number of key metrics like the implicit reward, which can be recorded and displayed via WandB, for example. We can then push the final trained model to the HuggingFace Hub.
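One way to do this final step, sketched here under the assumption that the DPO adapters were saved to "output_dir" and that the Hub repository name is a placeholder, is to merge the LoRA adapters back into the base weights with PEFT and push the result:

import torch
from peft import AutoPeftModelForCausalLM

# Reload the trained adapter on top of the base model, merge, and push.
model = AutoPeftModelForCausalLM.from_pretrained(
    "output_dir",                 # path where dpo_trainer.save_model() wrote the adapter
    torch_dtype=torch.bfloat16,
)
model = model.merge_and_unload()  # fold the LoRA weights into the base model
model.push_to_hub("my-username/stack-llama-2-dpo")  # placeholder repo name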
Conclusion
The full source code of the training scripts for SFT and DPO is available in the following examples/stack_llama_2 directory and the trained model with the merged adapters can be found on the HF Hub here.
The WandB logs for the DPO training run can be found here, where during training and evaluation the DPOTrainer records the following reward metrics:
- rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
- rewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
- rewards/accuracies: mean of how often the chosen rewards are greater than the corresponding rejected rewards
- rewards/margins: the mean difference between the chosen and corresponding rejected rewards
Intuitively, during training we want the margins to increase and the accuracies to go to 1.0, or in other words we want the chosen reward to be higher than the rejected reward (or the margin to be larger than zero). These metrics can then also be calculated over some evaluation dataset.
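As a rough sketch of how these quantities relate to the log probabilities (variable names are illustrative, mirroring the loss sketch above, not TRL internals):

# policy_*_logps / ref_*_logps are the summed log probabilities of the chosen
# and rejected responses under the policy and reference models, respectively.
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
reward_margins = (chosen_rewards - rejected_rewards).mean()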
We hope the code release lowers the barrier to entry for you, the readers, to try out this approach to aligning large language models on your own datasets, and we can't wait to see what you build! And if you want to try out the model yourself you can do so here: trl-lib/stack-llama.
