Generative AI models, such as Stable Diffusion XL (SDXL), enable the creation of high-quality, realistic content with wide-ranging applications. However, harnessing the power of such models presents significant challenges and computational costs. SDXL is a large image generation model whose UNet component is about three times as large as the one in the previous version of the model. Deploying a model like this in production is challenging due to the increased memory requirements, as well as increased inference times. Today, we’re thrilled to announce that Hugging Face Diffusers now supports serving SDXL using JAX on Cloud TPUs, enabling high-performance, cost-efficient inference.
Google Cloud TPUs are custom-designed AI accelerators, which are optimized for training and inference of large AI models, including state-of-the-art LLMs and generative AI models such as SDXL. The new Cloud TPU v5e is purpose-built to bring the cost-efficiency and performance required for large-scale AI training and inference. At less than half the cost of TPU v4, TPU v5e makes it possible for more organizations to train and deploy AI models.
The 🧨 Diffusers JAX integration offers a convenient way to run SDXL on TPU via XLA, and we built a demo to showcase it. You can try it out in this Space or in the playground embedded below:
Under the hood, this demo runs on several TPU v5e-4 instances (each instance has 4 TPU chips) and takes advantage of parallelization to serve four large 1024×1024 images in about 4 seconds. This time includes format conversions, communication time, and frontend processing; the actual generation time is about 2.3s, as we’ll see below!
In this blog post, we will:
- Describe why JAX + TPU + Diffusers is a powerful framework to run SDXL
- Explain how you can write a simple image generation pipeline with Diffusers and JAX
- Show benchmarks comparing different TPU settings
Why JAX + TPU v5e for SDXL?
Serving SDXL with JAX on Cloud TPU v5e with high performance and cost-efficiency is possible thanks to the combination of purpose-built TPU hardware and a software stack optimized for performance. Below we highlight two key factors: JAX just-in-time (jit) compilation and XLA compiler-driven parallelism with JAX pmap.
JIT compilation
A notable feature of JAX is its just-in-time (jit) compilation. The JIT compiler traces the code during the first run and generates highly optimized TPU binaries that are re-used in subsequent calls.
The catch of this process is that it requires all input, intermediate, and output shapes to be static, meaning that they must be known in advance. Every time we change the shapes, a new and costly compilation process is triggered again. JIT compilation is ideal for services that can be designed around static shapes: compilation runs once, and then we take advantage of super-fast inference times.
Image generation is well-suited for JIT compilation. If we always generate the same number of images and they have the same size, then the output shapes are constant and known in advance. The text inputs are also constant: by design, Stable Diffusion and SDXL use fixed-shape embedding vectors (with padding) to represent the prompts typed by the user. Therefore, we can write JAX code that relies on fixed shapes, and that can be greatly optimized!
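To make the shape requirement concrete, here is a minimal sketch (unrelated to the SDXL pipeline itself) of how jit caches compiled code by input shape:

import jax
import jax.numpy as jnp

@jax.jit
def scale_and_sum(x):
    return jnp.sum(x * 2.0)

x = jnp.ones((4, 1024))
scale_and_sum(x)                     # first call: traces and compiles
scale_and_sum(x + 1.0)               # same shape: reuses the compiled binary
scale_and_sum(jnp.ones((8, 1024)))   # new shape: triggers a fresh compilation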
High-performance throughput for high batch sizes
Workloads can be scaled across multiple devices using JAX’s pmap, which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA, then executes it in parallel on the available XLA devices.
For text-to-image generation workloads, this means that increasing the number of images rendered concurrently is simple to implement and doesn’t compromise performance. For example, running SDXL on a TPU with 8 chips will generate 8 images in the same time it takes for 1 chip to create a single image.
TPU v5e instances come in multiple shapes, including 1, 4, and 8-chip shapes, all the way up to 256 chips (a full TPU v5e pod), with ultra-fast ICI links between chips. This allows you to choose the TPU shape that best fits your use case and easily take advantage of the parallelism that JAX and TPUs provide.
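As a minimal illustration of pmap (independent of the SDXL code below), the decorated function is compiled once and then runs on every available device, each working on its own slice of the leading batch dimension:

import jax
import jax.numpy as jnp

@jax.pmap
def scaled_norm(x):
    # runs once per device, each on its own slice of the batch
    return jnp.linalg.norm(x) * 2.0

num_devices = jax.device_count()
batch = jnp.ones((num_devices, 128))   # one batch element per device
print(scaled_norm(batch))              # one result per device, computed in parallel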
How to write an image generation pipeline in JAX
We’ll go step by step over the code you need to write to run inference super fast using JAX! First, let’s import the dependencies.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time
We’ll now load the base SDXL model and the rest of the components required for inference. The diffusers pipeline takes care of downloading and caching everything for us. Adhering to JAX’s functional approach, the model’s parameters are returned separately and will have to be passed to the pipeline during inference:
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)
Model parameters are downloaded in 32-bit precision by default. To save memory and run computation faster we’ll convert them to bfloat16, an efficient 16-bit representation. However, there’s a caveat: for best results, we have to keep the scheduler state in float32, otherwise precision errors accumulate and result in low-quality or even black images.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
We are now ready to set up our prompt and the rest of the pipeline inputs.
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
The prompts have to be supplied as tensors to the pipeline, and they always have to have the same dimensions across invocations. This allows the inference call to be compiled. The pipeline’s prepare_inputs method performs all the necessary steps for us, so we’ll create a helper function to prepare both our prompt and negative prompt as tensors. We’ll use it later from our generate function:
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
To take advantage of parallelization, we’ll replicate the inputs across devices. A Cloud TPU v5e-4 has 4 chips, so by replicating the inputs we get each chip to generate a different image, in parallel. We need to be careful to supply a different random seed to each chip so the 4 images are different:
NUM_DEVICES = jax.device_count()
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng
We are now ready to put everything together in a generate function:
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
jit=True indicates that we want the pipeline call to be compiled. This will happen the first time we call generate, and it will be very slow – JAX needs to trace the operations, optimize them, and convert them to low-level primitives. We’ll run a first generation to complete this process and warm things up:
start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
This took about three minutes the first time we ran it.
But once the code has been compiled, inference will be super fast. Let’s try it again!
start = time.time()
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")
This time it took about 2s to generate the 4 images!
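The call returns standard PIL images, so saving them to disk uses the regular Pillow API (the filenames below are just an example):

for i, image in enumerate(images):
    image.save(f"llama_{i}.png")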
Benchmark
The following measurements were obtained running SDXL 1.0 base for 20 steps, with the default Euler Discrete scheduler. We compare Cloud TPU v5e with TPU v4 for the same batch sizes. Do note that, due to parallelism, a TPU v5e-4 like the ones we use in our demo will generate 4 images when using a batch size of 1 (or 8 images with a batch size of 2). Similarly, a TPU v5e-8 will generate 8 images when using a batch size of 1.
The Cloud TPU tests were run using Python 3.10 and jax version 0.4.16. These are the same specs used in our demo Space.
| Hardware | Batch Size | Latency | Perf/$ |
|---|---|---|---|
| TPU v5e-4 (JAX) | 4 | 2.33s | 21.46 |
| TPU v5e-4 (JAX) | 8 | 4.99s | 20.04 |
| TPU v4-8 (JAX) | 4 | 2.16s | 9.05 |
| TPU v4-8 (JAX) | 8 | 4.17s | 8.98 |
TPU v5e achieves up to 2.4x greater perf/$ on SDXL compared to TPU v4, demonstrating the cost-efficiency of the latest TPU generation.
To measure inference performance, we use the industry-standard metric of throughput. First, we measure latency per image when the model has been compiled and loaded. Then, we calculate throughput by dividing batch size by latency per chip. As a result, throughput measures how the model performs in production environments regardless of how many chips are used. We then divide throughput by the list price to get performance per dollar.
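As an illustrative sketch of that calculation (the list price is a placeholder argument here, so this snippet does not reproduce the exact figures in the table):

def throughput(batch_size, latency_s, num_chips):
    # images per second, normalized per chip, so the metric is
    # independent of how many chips are used
    return batch_size / latency_s / num_chips

def perf_per_dollar(batch_size, latency_s, num_chips, list_price):
    # divide throughput by the (placeholder) list price to get performance per dollar
    return throughput(batch_size, latency_s, num_chips) / list_price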
How does the demo work?
The demo we showed earlier was built using a script that essentially follows the code we posted in this blog post. It runs on a few Cloud TPU v5e devices with 4 chips each, and there’s a simple load-balancing server that routes user requests to the backend servers randomly. When you enter a prompt in the demo, your request will be assigned to one of the backend servers, and you’ll receive the 4 images it generates.
This is a simple solution based on several pre-allocated TPU instances. In a future post, we’ll cover how to create dynamic solutions that adapt to load using GKE.
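As a very rough illustration of the random routing described above (the backend URLs, endpoint path, and payload are hypothetical, not the actual demo code):

import random
import requests

# Hypothetical backends: each one wraps a TPU v5e-4 host running the pipeline above.
BACKENDS = ["http://tpu-backend-0:8000", "http://tpu-backend-1:8000"]

def route_request(prompt: str, negative_prompt: str):
    # Pick a backend at random and forward the generation request to it.
    backend = random.choice(BACKENDS)
    response = requests.post(
        f"{backend}/generate",
        json={"prompt": prompt, "negative_prompt": negative_prompt},
    )
    return response.json()  # e.g. the four generated images, base64-encoded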
All the code for the demo is open source and available in Hugging Face Diffusers today. We’re excited to see what you build with Diffusers + JAX + Cloud TPUs!
