🧨 Stable Diffusion in JAX / Flax!

Pedro Cuenca, Patrick von Platen



Open In Colab

🤗 Hugging Face Diffusers supports Flax since version 0.5.1! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.

This post shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to this Colab notebook.

If you wish to follow along, click the button above to open this post as a Colab notebook.

First, make sure you are using a TPU backend. If you are running this notebook in Colab, select Runtime in the menu above, then select the option “Change runtime type”, and then select TPU under the Hardware accelerator setting.

Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.



Setup

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

Output:

    Found 8 JAX devices of type TPU v2.

Make sure diffusers is installed.

!pip install diffusers==0.5.1

Then we import all of the dependencies.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline



Model Loading

Before using the model, you need to accept the model license in order to download and use the weights.

The license is designed to mitigate the potential harmful effects of such a powerful machine learning system.
We request users to read the license entirely and carefully. Here we provide a summary:

  1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content,
  2. We claim no rights on the outputs you generate, you are free to use them and are accountable for their use which should not go against the provisions set in the license, and
  3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M with all your users.

Flax weights are available in the Hugging Face Hub as part of the Stable Diffusion repo. The Stable Diffusion model is distributed under the CreativeML OpenRAIL-M license. It’s an open license that claims no rights on the outputs you generate and prohibits you from deliberately producing illegal or harmful content. The model card provides more details, so take a moment to read them and consider carefully whether you accept the license. If you do, you need to be a registered user in the Hub and use an access token for the code to work. You have two options to provide your access token:

  • Use the huggingface-cli login command-line tool in your terminal and paste your token when prompted. It will be saved in a file on your computer.
  • Or use notebook_login() in a notebook, which does the same thing.

The following cell will present a login interface unless you’ve already authenticated before on this computer. You’ll need to paste your access token.

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPU devices support bfloat16, an efficient half-float type. We’ll use it for our tests, but you can also use float32 for full precision instead.

dtype = jnp.bfloat16

Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a bf16 version of the weights, which leads to type warnings that you can safely ignore.

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)



Inference

Since TPUs usually have 8 devices working in parallel, we’ll replicate our prompt as many times as there are devices. Then we’ll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we’ll get 8 images in the same amount of time it takes one chip to generate a single one.

After replicating the prompt, we obtain the tokenized text ids by invoking the prepare_inputs function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP text model.

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

Output:

    (8, 77)



Replication and parallelization

Model parameters and inputs have to be replicated across the 8 parallel devices we have. The parameters dictionary is replicated using flax.jax_utils.replicate, which traverses the dictionary and changes the shape of the weights so that they are repeated 8 times. Arrays are replicated using shard.

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape

Output:

    (8, 1, 77)

That shape means that each one of the 8 devices will receive as an input a jnp array with shape (1, 77). 1 is therefore the batch size per device. In TPUs with sufficient memory, it could be larger than 1 if we wanted to generate multiple images (per chip) at once.
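For instance, a larger per-device batch could look like the following sketch. It is not run in this post, the prompt text and variable names are just illustrative, and it assumes your TPU has enough memory for a batch size of 2 per device:

# Sketch only: a per-device batch size of 2 (16 prompts in total on an 8-device TPU).
batch_per_device = 2
many_prompts = ["A watercolor painting of a lighthouse"] * (jax.device_count() * batch_per_device)
many_ids = pipeline.prepare_inputs(many_prompts)
many_ids = shard(many_ids)
print(many_ids.shape)  # (8, 2, 77): 8 devices, 2 prompts per device, 77 tokens each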

We’re almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers – all functions that deal with random numbers are expected to receive a generator. This ensures reproducibility, even when we are training across multiple distributed devices.

The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we’ll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

We obtain a rng and then “split” it 8 times so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX code can be compiled to an efficient representation that runs very fast. However, we need to make sure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn’t be able to take advantage of the optimized speed.
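As a quick illustration of that behavior, here is a toy function unrelated to the pipeline, shown only to make the recompilation point concrete:

# Toy example: jax.jit traces (and compiles) once per distinct input shape.
@jax.jit
def scaled_sum(x):
    print("tracing!")  # side effects only run while JAX traces the function
    return (2 * x).sum()

scaled_sum(jnp.ones((8, 77)))   # first call: traces and compiles
scaled_sum(jnp.zeros((8, 77)))  # same shape: reuses the compiled code
scaled_sum(jnp.ones((4, 77)))   # different shape: triggers a recompilation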

The Flax pipeline can compile the code for us if we pass jit = True as an argument. It will also ensure that the model runs in parallel on the 8 available devices.

The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile on a TPU v2-8 when I tested, but then it takes about 7s for future inference runs.

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

Output:

    CPU times: user 464 ms, sys: 105 ms, total: 569 ms
    Wall time: 7.07 s

The returned array has shape (8, 1, 512, 512, 3). We reshape it to get rid of the second dimension and obtain 8 images of 512 × 512 × 3, and then convert them to PIL.

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)



Visualization

Let’s create a helper function to display images in a grid.

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)




Using different prompts

We don’t have to replicate the same prompt in all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let’s do that!

First, we prepare the inputs for the new prompts the same way we did before:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)





How does parallelization work?

We said before that the diffusers Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We’ll now briefly look inside that process to show how it works.

JAX parallelization can be done in multiple ways. The simplest one revolves around using the jax.pmap function to achieve single-program, multiple-data (SPMD) parallelization. It means we’ll run several copies of the same code, each on different data inputs; a tiny standalone example follows the list below. More sophisticated approaches are possible; we invite you to go over the JAX documentation and the pjit pages to explore this topic if you are interested!

jax.pmap does two things for us:

  • Compiles (or jits) the code, as if we had invoked jax.jit(). This does not happen when we call pmap, but the first time the pmapped function is invoked.
  • Ensures the compiled code runs in parallel on all the available devices.
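Here is that tiny standalone sketch of SPMD with pmap, unrelated to the pipeline. The normalize function and the input array are made up just for illustration:

# Toy SPMD example: the same function runs on every device, each on its own slice of the input.
def normalize(row):
    # Written as if it processes a single row; pmap takes care of the parallelism.
    return row / jnp.linalg.norm(row)

n = jax.device_count()
xs = jnp.arange(float(n * 4)).reshape(n, 4)  # leading axis must match the number of devices
ys = jax.pmap(normalize)(xs)                 # runs on all devices at once
print(ys.shape)                              # (8, 4) on an 8-device TPU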

To show how it works, we pmap the _generate method of the pipeline, which is the private method that runs image generation. Please note that this method may be renamed or removed in future releases of diffusers.

p_generate = pmap(pipeline._generate)

When we use pmap, the prepared function p_generate will conceptually do the following:

  • Invoke a copy of the underlying function pipeline._generate on each device.
  • Send each device a different portion of the input arguments. That’s what sharding is used for. In our case, prompt_ids has shape (8, 1, 77). This array will be split in 8 and each copy of _generate will receive an input with shape (1, 77).

We can code _generate completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (1 in this example) and the dimensions that make sense for our code, and don’t have to change anything to make it work in parallel.

Just as when we used the pipeline call, the first time we run the following cell it will take a while, but subsequent calls will be much faster.

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

Output:

    CPU times: user 118 ms, sys: 83.9 ms, total: 202 ms
    Wall time: 6.82 s

    (8, 1, 512, 512, 3)

We use block_until_ready() to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don’t need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
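A rough way to see the effect is the sketch below. It reuses the objects defined above, and the timings it reports are only illustrative:

# Sketch: separate the cost of dispatching work from the cost of the computation itself.
import time

start = time.perf_counter()
images = p_generate(prompt_ids, p_params, rng)  # returns almost immediately: work is dispatched asynchronously
dispatched = time.perf_counter()
images.block_until_ready()                      # blocks until the devices actually finish
done = time.perf_counter()

print(f"dispatch took {dispatched - start:.3f}s, full computation took {done - start:.2f}s")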


