🤗 Hugging Face Diffusers supports Flax since version 0.5.1! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.
This post shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on GPU, please refer to this Colab notebook.
If you wish to follow along, click the button above to open this post as a Colab notebook.
First, make sure you are using a TPU backend. If you are running this notebook in Colab, select Runtime in the menu above, then select the option “Change runtime type”, and then select TPU under the Hardware accelerator setting.
Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.
Setup
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is just not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Output:
Found 8 JAX devices of type TPU v2.
Make sure diffusers is installed.
!pip install diffusers==0.5.1
Then we import all of the dependencies.
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
Model Loading
Before using the model, you need to accept the model license in order to download and use the weights.
The license is designed to mitigate the potential harmful effects of such a powerful machine learning system.
We request users to read the license entirely and carefully. Here we provide a summary:
- You can't use the model to deliberately produce nor share illegal or harmful outputs or content,
- We claim no rights on the outputs you generate, you are free to use them and are accountable for their use which should not go against the provisions set in the license, and
- You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M with all your users.
Flax weights are available in the Hugging Face Hub as part of the Stable Diffusion repo. The Stable Diffusion model is distributed under the CreativeML OpenRAIL-M license. It's an open license that claims no rights on the outputs you generate and prohibits you from deliberately producing illegal or harmful content. The model card provides more details, so take a moment to read them and consider carefully whether you accept the license. If you do, you need to be a registered user in the Hub and use an access token for the code to work. You have two options to provide your access token:
- Use the huggingface-cli login command-line tool in your terminal and paste your token when prompted. It will be saved in a file on your computer.
- Or use notebook_login() in a notebook, which does the same thing.
The following cell will present a login interface unless you've already authenticated before on this computer. You'll need to paste your access token.
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
TPU devices support bfloat16, an efficient half-float type. We'll use it for our tests, but you can also use float32 to use full precision instead.
dtype = jnp.bfloat16
Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a bf16 version of the weights, which leads to type warnings that you can safely ignore.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=dtype,
)
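If you prefer full precision instead, one option (a sketch of ours, not from the original post; it assumes float32 Flax weights are available on the default revision of the repo) is to skip the bf16 branch:
# Full-precision variant (ours): load the default Flax weights in float32
# instead of the bf16 revision used above.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    dtype=jnp.float32,
)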
Inference
Because TPUs usually have 8 devices working in parallel, we'll replicate our prompt as many times as devices we have. Then we'll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we'll get 8 images in the same amount of time it takes for one chip to generate a single one.
After replicating the prompt, we obtain the tokenized text ids by invoking the prepare_inputs function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP Text model.
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
Output:
(8, 77)
Replication and parallelization
Model parameters and inputs have to be replicated across the 8 parallel devices we have. The parameters dictionary is replicated using flax.jax_utils.replicate, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using shard.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
Output:
(8, 1, 77)
That shape means that each one of the 8 devices will receive as an input a jnp array with shape (1, 77). 1 is therefore the batch size per device. On TPUs with sufficient memory, it could be larger than 1 if we wanted to generate multiple images (per chip) at once, as sketched below.
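For example (a quick sketch of ours, not part of the original post), with two prompts per device shard groups the tokenized ids into an (8, 2, 77) array:
# Our own sketch: shard reshapes the leading axis into (num_devices, batch_per_device).
batch_per_device = 2
big_prompt = [prompt[0]] * jax.device_count() * batch_per_device  # prompt[0] is the string above
big_prompt_ids = shard(pipeline.prepare_inputs(big_prompt))
big_prompt_ids.shape  # (8, 2, 77)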
We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers – all functions that deal with random numbers are expected to receive a generator. This ensures reproducibility, even when we are training across multiple distributed devices.
The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we'll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.
def create_key(seed=0):
return jax.random.PRNGKey(seed)
We obtain a rng and then “split” it 8 times so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible.
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn't be able to take advantage of the optimized speed.
The Flax pipeline can compile the code for us if we pass jit = True as an argument. It will also ensure that the model runs in parallel on the 8 available devices.
The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile on a TPU v2-8 when I tested, but then it takes about 7s for future inference runs.
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
Output:
CPU times: user 464 ms, sys: 105 ms, total: 569 ms
Wall time: 7.07 s
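For instance (a variation of ours, not shown in the original post), calling the pipeline again with new RNG keys keeps every input shape identical, so the compiled code is reused and the call stays fast:
# Ours: new random keys, same shapes, so no recompilation is triggered.
rng_new = jax.random.split(jax.random.PRNGKey(42), jax.device_count())
images_new = pipeline(prompt_ids, p_params, rng_new, jit=True)[0]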
The returned array has shape (8, 1, 512, 512, 3). We reshape it to get rid of the second dimension and obtain 8 images of 512 × 512 × 3, and then convert them to PIL.
images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
Visualization
Let’s create a helper function to display images in a grid.
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)
Using different prompts
We don't have to replicate the same prompt in all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let's do that!
First, we tokenize and shard a list of 8 different prompts, just as we did before:
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
How does parallelization work?
We said before that the diffusers Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works.
JAX parallelization can be done in multiple ways. The easiest one revolves around using the jax.pmap function to achieve single-program, multiple-data (SPMD) parallelization. It means we'll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; we invite you to go over the JAX documentation and the pjit pages to explore this topic if you are interested!
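As a toy illustration (our own example, unrelated to the pipeline), pmap runs the same function on every device, each operating on its own slice of the leading axis:
# Our toy SPMD example: each device squares its own row of the input array.
xs = jnp.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)
ys = jax.pmap(lambda x: x ** 2)(xs)
ys.shape  # (8, 4) on a TPU v2-8: one result per device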
jax.pmap does two things for us:
- Compiles (or jits) the code, as if we had invoked jax.jit(). This does not happen when we call pmap, but the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices.
To show how it works, we pmap the _generate method of the pipeline, which is the private method that runs the image generation. Please note that this method may be renamed or removed in future releases of diffusers.
p_generate = pmap(pipeline._generate)
When we use pmap, the prepared function p_generate will conceptually do the following:
- Invoke a copy of the underlying function pipeline._generate in each device.
- Send each device a different portion of the input arguments. That's what sharding is used for. In our case, prompt_ids has shape (8, 1, 77). This array will be split in 8, and each copy of _generate will receive an input with shape (1, 77).
We can code _generate completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (1 in this example) and the dimensions that make sense for our code, and don't have to change anything to make it work in parallel.
The same way as when we used the pipeline call, the first time we run the following cell it will take a while, but then it will be much faster.
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
Output:
CPU times: user 118 ms, sys: 83.9 ms, total: 202 ms
Wall time: 6.82 s
(8, 1, 512, 512, 3)
We use block_until_ready() to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
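If you want to see the effect of asynchronous dispatch yourself, a small sketch (ours, not part of the original post) is to time the call and the blocking step separately:
import time

# Ours: the call returns almost immediately (asynchronous dispatch); the
# actual TPU work is only waited for in block_until_ready().
start = time.perf_counter()
out = p_generate(prompt_ids, p_params, rng)
print(f"dispatched after {time.perf_counter() - start:.3f}s")
out.block_until_ready()
print(f"finished after {time.perf_counter() - start:.3f}s")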