Diffusion models (e.g., DALL-E 2, Stable Diffusion) are a class of generative models that have been widely successful at generating images, most notably photorealistic ones. However, the images generated by these models may not always be on par with human preference or human intention. Thus arises the alignment problem: how does one ensure that the outputs of a model are aligned with human preferences like "quality", or with intent that is difficult to express via prompts? This is where Reinforcement Learning comes into the picture.
In the world of Large Language Models (LLMs), Reinforcement Learning (RL) has proven to be a very effective tool for aligning said models to human preferences. It's one of the main recipes behind the superior performance of systems like ChatGPT. More precisely, RL is the critical ingredient of Reinforcement Learning from Human Feedback (RLHF), which makes ChatGPT chat like human beings.
In Training Diffusion Models with Reinforcement Learning, Black et al. show how to augment diffusion models to leverage RL to fine-tune them with respect to an objective function via a method named Denoising Diffusion Policy Optimization (DDPO).
In this blog post, we discuss how DDPO came to be, give a brief description of how it works, and show how DDPO can be incorporated into an RLHF workflow to achieve model outputs more aligned with human aesthetics. We then quickly switch gears to talk about how you can apply DDPO to your models with the newly integrated DDPOTrainer from the trl library and discuss our findings from running DDPO on Stable Diffusion.
The Benefits of DDPO
DDPO is not the only working answer to the question of how to fine-tune diffusion models with RL.
Before diving in, there are two key points to remember when it comes to understanding the advantages of one RL solution over another:
- Computational efficiency is key. The more complicated your data distribution gets, the higher your computational costs get.
- Approximations are nice, but because approximations are not the real thing, associated errors stack up.
Before DDPO, Reward-Weighted Regression (RWR) was an established way of using Reinforcement Learning to fine-tune diffusion models. RWR reuses the denoising loss function of the diffusion model along with training data sampled from the model itself and per-sample loss weighting that depends on the reward associated with the final samples. This algorithm ignores the intermediate denoising steps/samples. While this works, two things should be noted:
- Optimizing by weighting the associated loss, which is a maximum likelihood objective, is an approximate optimization
- The associated loss is not an exact maximum likelihood objective but an approximation derived from a reweighted variational bound
The two levels of approximation have a significant impact on both performance and the ability to handle complex objectives.
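To make the contrast with DDPO concrete, here is a minimal conceptual sketch of reward-weighted regression (the function name, call signature, and exponential weighting scheme are illustrative choices, not the exact algorithm from any particular implementation): the standard denoising loss is simply re-weighted per sample by a function of the final reward, and the intermediate denoising steps are never treated as decisions in their own right.

```python
import torch
import torch.nn.functional as F

def rwr_loss(unet, noisy_latents, timesteps, text_embeddings, noise, rewards, beta=1.0):
    """Illustrative reward-weighted regression loss.

    The usual diffusion denoising loss is weighted per sample by a monotonic
    function of the final reward (here exp(reward / beta)); the intermediate
    denoising steps are not modeled as an MDP.
    """
    noise_pred = unet(noisy_latents, timesteps, text_embeddings).sample
    per_sample_loss = F.mse_loss(noise_pred, noise, reduction="none").mean(dim=(1, 2, 3))
    weights = torch.exp(rewards / beta)        # higher-reward samples count more
    return (weights * per_sample_loss).mean()  # no notion of per-step decisions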
DDPO uses this method as a starting point. Rather than viewing the denoising process as a single step by only focusing on the final sample, DDPO frames the whole denoising process as a multistep Markov Decision Process (MDP) where the reward is received at the very end. This formulation, along with using a fixed sampler, paves the way for the agent policy to become an isotropic Gaussian as opposed to an arbitrarily complicated distribution. So instead of using the approximate likelihood of the final sample (which is the path RWR takes), here the exact likelihood of each denoising step is used, which is extremely easy to compute ( \( \ell(\mu, \sigma^2; x) = -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^{n} (x_i - \mu)^2 \) ).
If you're interested in learning more details about DDPO, we encourage you to check out the original paper and the accompanying blog post.
DDPO algorithm briefly
Given the MDP framework used to model the sequential nature of the denoising process, along with the rest of the considerations that follow, the tool of choice to tackle the optimization problem is a policy gradient method, specifically Proximal Policy Optimization (PPO). The whole DDPO algorithm is pretty much identical to PPO; as an aside, the portion that stands out as highly customized is the trajectory collection portion of PPO.
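To make this concrete, below is a rough, illustrative sketch (not trl's actual implementation) of the two computations at the heart of DDPO: the exact log-likelihood of a single denoising step under the isotropic Gaussian policy, and the PPO-style clipped objective applied to those per-step log-probabilities, with advantages derived from the reward on the final image.

```python
import math
import torch

def gaussian_log_prob(sample, mean, std):
    # Exact log-likelihood of one denoising step under an isotropic Gaussian policy.
    return (
        -((sample - mean) ** 2) / (2 * std ** 2)
        - torch.log(std)
        - 0.5 * math.log(2 * math.pi)
    ).sum(dim=tuple(range(1, sample.dim())))  # sum over the latent dimensions

def ddpo_ppo_loss(log_probs, old_log_probs, advantages, clip_range=1e-4):
    # PPO-style clipped surrogate over denoising steps; the advantages come from
    # the reward on the final image (optionally normalized per prompt).
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))
```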
Here’s a diagram to summarize the flow:
DDPO and RLHF: a mix to enforce aestheticness
The general training workflow of RLHF can roughly be broken down into the following steps:
- Supervised fine-tuning of a "base" model, which learns the distribution of some new data
- Gathering preference data and training a reward model using it.
- Fine-tuning the model with reinforcement learning using the reward model as a signal.
It should be noted that preference data is the primary source for capturing human feedback in the context of RLHF.
When we add DDPO to the mix, the workflow is morphed into the following:
- Starting with a pretrained Diffusion Model
- Gathering preference data and training a reward model using it.
- Fine-tuning the model with DDPO using the reward model as a signal
Notice that step 1 of the general RLHF workflow (supervised fine-tuning) does not appear in the latter list of steps, and this is because it has been empirically shown (as you will get to see yourself) that it is not needed.
To get on with our venture of getting a diffusion model to output images more in line with the human-perceived notion of what it means to be aesthetic, we follow these steps:
- Starting with a pretrained Stable Diffusion (SD) model
- Training a frozen CLIP model with a trainable regression head on the Aesthetic Visual Analysis (AVA) dataset to predict how much people like an input image on average (a rough sketch of such a scorer follows this list)
- Fine-tuning the SD model with DDPO using the aesthetic predictor model as the reward signal
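The sketch below shows what such an aesthetic scorer can look like. It is illustrative only: the actual reward model used later in this post has its weights read from a public Hugging Face repo, and the CLIP checkpoint name and head architecture here are assumptions, not the exact ones used.

```python
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor

class AestheticScorer(nn.Module):
    """Frozen CLIP image encoder + small trainable regression head.

    The head is trained on the AVA dataset to predict the mean human aesthetic
    rating of an image; at DDPO time the whole module is used as a fixed reward model.
    """

    def __init__(self, clip_name="openai/clip-vit-large-patch14"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)
        for p in self.clip.parameters():  # CLIP stays frozen; only the head is trained
            p.requires_grad_(False)
        self.head = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, images):
        inputs = self.processor(images=images, return_tensors="pt")
        embeds = self.clip.get_image_features(pixel_values=inputs["pixel_values"])
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        return self.head(embeds).squeeze(-1)  # one scalar aesthetic score per image
```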
We keep these steps in mind while moving on to actually getting them running, which is described in the following sections.
Training Stable Diffusion with DDPO
Setup
To get started, when it comes to the hardware side of things with this implementation of DDPO, at the very least access to an A100 NVIDIA GPU is required for successful training. Anything below this GPU type will quickly run into out-of-memory issues.
Use pip to install the trl library:
pip install trl[diffusers]
This should get the main library installed. The following dependencies are for tracking and image logging. After installing wandb, remember to log in (`wandb login`) so that results are saved to a personal account:
pip install wandb torchvision
Note: you may choose to use tensorboard rather than wandb, in which case you'd need to install the tensorboard package via pip.
A Walkthrough
The main classes within the trl library responsible for DDPO training are the DDPOTrainer and DDPOConfig classes. See the docs for more general info on DDPOTrainer and DDPOConfig. There is an example training script in the trl repo that uses both of these classes in tandem with default implementations of required inputs and default parameters to finetune a default pretrained Stable Diffusion model from RunwayML.
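For orientation, here is a condensed sketch of how these pieces fit together, based on the trl DDPO API at the time of writing. The prompt and reward functions below are stand-in placeholders (the example script ships real ones), so consult the docs and the script for the exact signatures; the config values mirror the recommended single-GPU hyperparameters listed further below.

```python
import torch
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline

def prompt_fn():
    # The example script samples animal names; a fixed prompt is used here for brevity.
    return "a photo of a cute animal", {}

def reward_fn(images, prompts, prompt_metadata):
    # Stand-in reward: replace with e.g. an aesthetic scorer. Must return (rewards, metadata).
    rewards = torch.randn(len(images), device=images.device)
    return rewards, {}

config = DDPOConfig(
    num_epochs=200,
    train_batch_size=3,
    sample_batch_size=6,
    sample_num_steps=50,
    sample_num_batches_per_epoch=4,
    per_prompt_stat_tracking=True,
    per_prompt_stat_tracking_buffer_size=32,
    train_learning_rate=3e-4,
)

pipeline = DefaultDDPOStableDiffusionPipeline(
    "runwayml/stable-diffusion-v1-5", use_lora=True
)

trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()
```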
This example script uses wandb for logging and an aesthetic reward model whose weights are read from a public-facing HuggingFace repo (so gathering data and training the aesthetic reward model is already done for you). The default prompt dataset used is a list of animal names.
There is only one command-line flag argument required of the user to get things up and running. Additionally, the user is expected to have a Hugging Face user access token that will be used to upload the model to the Hugging Face Hub after finetuning.
The following bash command gets things running:
python ddpo.py --hf_user_access_token <hf_user_access_token>
The following table contains key hyperparameters that are directly correlated with positive results:
| Parameter | Description | Recommended value for single GPU training (as of now) |
|---|---|---|
| `num_epochs` | The number of epochs to train for | 200 |
| `train_batch_size` | The batch size to use for training | 3 |
| `sample_batch_size` | The batch size to use for sampling | 6 |
| `gradient_accumulation_steps` | The number of accelerator-based gradient accumulation steps to use | 1 |
| `sample_num_steps` | The number of steps to sample for | 50 |
| `sample_num_batches_per_epoch` | The number of batches to sample per epoch | 4 |
| `per_prompt_stat_tracking` | Whether to track stats per prompt. If false, advantages will be calculated using the mean and std of the entire batch, as opposed to tracking per prompt | `True` |
| `per_prompt_stat_tracking_buffer_size` | The size of the buffer to use for tracking stats per prompt | 32 |
| `mixed_precision` | Mixed precision training | `True` |
| `train_learning_rate` | Learning rate | 3e-4 |
The provided script is merely a starting point. Feel free to adjust the hyperparameters or even overhaul the script to accommodate different objective functions. For instance, one could integrate a function that gauges JPEG compressibility or one that evaluates visual-text alignment using a multi-modal model, among other possibilities.
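As an illustration of how small such a swap can be, here is a hedged sketch of a JPEG-compressibility reward function. The name, the quality setting, and the assumption that images arrive as float tensors in [0, 1] are our own choices, not part of the example script:

```python
import io
import torch
from PIL import Image

def jpeg_compressibility_reward(images, prompts, prompt_metadata):
    # images: batch of float tensors in [0, 1] with shape (B, C, H, W)
    rewards = []
    for img in images:
        array = (img.clamp(0, 1) * 255).byte().permute(1, 2, 0).cpu().numpy()
        buffer = io.BytesIO()
        Image.fromarray(array).save(buffer, format="JPEG", quality=95)
        rewards.append(-buffer.getbuffer().nbytes / 1024)  # smaller JPEG => higher reward
    return torch.tensor(rewards), {}
```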
Lessons learned
- The results appear to generalize over a wide variety of prompts despite the minimal size of the training prompt set. This has been thoroughly verified for the objective function that rewards aesthetics.
- Attempts to explicitly generalize, at least for the aesthetic objective function, by increasing the training prompt set size and varying the prompts appear to slow down the convergence rate, for barely noticeable gains in learned general behavior, if they exist at all.
- While LoRA is recommended and has been tried and tested multiple times, non-LoRA training is also something to consider, among other reasons because empirical evidence suggests non-LoRA does seem to produce relatively more intricate images than LoRA. However, getting the right hyperparameters for a stable non-LoRA run is significantly more challenging.
- Recommendations for the config parameters for non-LoRA runs are: set the learning rate relatively low, something around `1e-5` should do the trick, and set `mixed_precision` to `None` (see the sketch below).
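For reference, a minimal sketch of those non-LoRA tweaks, reusing the same config and pipeline names as in the earlier sketch (again an assumption about the API, so double-check against the docs):

```python
from trl import DDPOConfig, DefaultDDPOStableDiffusionPipeline

# Non-LoRA run: lower learning rate and no mixed precision, per the recommendations above.
config = DDPOConfig(train_learning_rate=1e-5, mixed_precision=None)
pipeline = DefaultDDPOStableDiffusionPipeline(
    "runwayml/stable-diffusion-v1-5", use_lora=False
)
```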
Results
The following are pre-finetuned (left) and post-finetuned (right) outputs for the prompts bear, heaven and dune (each row corresponds to the outputs of a single prompt):
Limitations
- Right now, `trl`'s DDPOTrainer is limited to finetuning vanilla SD models.
- In our experiments we primarily focused on LoRA, which works very well. We did a few experiments with full training, which can lead to better quality, but finding the right hyperparameters is more challenging.
Conclusion
Diffusion models like Stable Diffusion, when fine-tuned using DDPO, can offer significant improvements in the quality of generated images as perceived by humans, or by any other metric, once it is properly conceptualized as an objective function.
The computational efficiency of DDPO and its ability to optimize without relying on approximations, especially compared with earlier methods aimed at the same goal of fine-tuning diffusion models, make it a suitable candidate for fine-tuning diffusion models like Stable Diffusion.
trl library’s DDPOTrainer implements DDPO for finetuning SD models.
Our experimental findings underline the strength of DDPO in generalizing across a broad range of prompts, although attempts at explicit generalization through varied prompts had mixed results. The difficulty of finding the right hyperparameters for non-LoRA setups also emerged as an important lesson.
DDPO is a promising technique for aligning diffusion models with any reward function, and we hope that with its release in TRL we can make it more accessible to the community!
Acknowledgements
Thanks to Chunte Lee for the thumbnail of this blog post.

