What’s Würstchen?
Würstchen is a diffusion model whose text-conditional component works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by orders of magnitude. Training on 1024×1024 images is way more expensive than training on 32×32. Usually, other works make use of a relatively small compression, in the range of 4x – 8x spatial compression. Würstchen takes this to an extreme. Through its novel design, it achieves a 42x spatial compression! This had never been seen before, because common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a two-stage compression, what we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the paper). Together, Stage A and B are called the Decoder, because they decode the compressed images back into pixel space. A third model, Stage C, is learned in that highly compressed latent space. This training requires fractions of the compute used for current top-performing models, while also allowing cheaper and faster inference. We refer to Stage C as the Prior.
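To make the three-stage split more concrete, here is a minimal sketch of running the Prior (Stage C) and the Decoder (Stages A and B) as separate diffusers pipelines; it assumes the warp-ai/wuerstchen-prior and warp-ai/wuerstchen checkpoints and mirrors the diffusers documentation rather than being the only way to use the model:

import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

# Stage C (the Prior): generates compressed latents from the text prompt
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")

# Stages A and B (the Decoder): decode the latents back into pixel space
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

caption = "Anthropomorphic cat dressed as a firefighter"
prior_output = prior_pipeline(
    prompt=caption,
    height=1024,
    width=1536,
    timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    guidance_scale=4.0,
)
images = decoder_pipeline(
    image_embeddings=prior_output.image_embeddings,
    prompt=caption,
    guidance_scale=0.0,
).images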
Why another text-to-image model?
Well, this one is pretty fast and efficient. Würstchen's biggest benefits come from the fact that it can generate images much faster than models like Stable Diffusion XL, while using a lot less memory! So for all of us who don't have A100s lying around, this will come in handy. Here's a comparison with SDXL over different batch sizes:
In addition to that, another hugely significant benefit of Würstchen comes with the reduced training costs. Würstchen v1, which works at 512×512, required only 9,000 GPU hours of training. Comparing this to the 150,000 GPU hours spent on Stable Diffusion 1.4 shows that this 16x reduction in cost not only benefits researchers when conducting new experiments, but it also opens the door for more organizations to train such models. Würstchen v2 used 24,602 GPU hours. With resolutions going up to 1536, this is still 6x cheaper than SD1.4, which was only trained at 512×512.
You can also find a detailed explanation video here:
How to use Würstchen?
You can either try it using the Demo here:
Otherwise, the model is available through the Diffusers Library, so you can use the interface you are already familiar with. For example, this is how to run inference using the AutoPipeline:
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
caption = "Anthropomorphic cat dressed as a firefighter"
images = pipeline(
    caption,
    height=1024,
    width=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=4,
).images
What image sizes does Würstchen work on?
Würstchen was trained on image resolutions between 1024×1024 and 1536×1536. We sometimes also observe good outputs at resolutions like 1024×2048. Feel free to try it out.
We also observed that the Prior (Stage C) adapts extremely fast to new resolutions, so fine-tuning it at 2048×2048 should be computationally cheap.
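For example, reusing the pipeline and caption from the inference example above, a non-square resolution can be requested like this (a sketch; output quality at unusual aspect ratios may vary):

images = pipeline(
    caption,
    height=1024,
    width=2048,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images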

Models on the Hub
All checkpoints can also be seen on the Hugging Face Hub. Multiple checkpoints, as well as future demos and model weights, can be found there. Right now there are 3 checkpoints for the Prior available and 1 checkpoint for the Decoder.
Take a look at the documentation, where the checkpoints are explained along with what the different Prior models are and can be used for.
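If you prefer to browse the available checkpoints programmatically, a small sketch using the generic huggingface_hub client (not something specific to Würstchen) could look like this:

from huggingface_hub import list_models

# List the checkpoints published by the warp-ai organization on the Hub
for model in list_models(author="warp-ai"):
    print(model.id)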
Diffusers integration
Because Würstchen is fully integrated in diffusers, it automatically comes with various goodies and optimizations out of the box. These include:
- Automatic use of PyTorch 2's SDPA-accelerated attention, as described below.
- Support for the xFormers flash attention implementation, if you need to use PyTorch 1.x instead of 2.
- Model offload, to move unused components to CPU while they are not in use. This saves memory with negligible performance impact (see the sketch after this list).
- Sequential CPU offload, for situations where memory is really precious. Memory use will be minimized, at the cost of slower inference.
- Prompt weighting with the Compel library.
- Support for the mps device on Apple Silicon Macs.
- Use of generators for reproducibility.
- Sensible defaults for inference to produce high-quality results in most situations. Of course you can tweak all parameters as you wish!
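As an illustration, here is a minimal sketch combining model offload with a seeded generator for reproducibility; it assumes the combined Würstchen pipeline exposes the usual diffusers offload helper:

import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
)
# Keep unused components on the CPU until they are needed, saving GPU memory
pipeline.enable_model_cpu_offload()

# A fixed seed makes the result reproducible
generator = torch.Generator(device="cuda").manual_seed(42)
images = pipeline(
    "Anthropomorphic cat dressed as a firefighter",
    height=1024,
    width=1536,
    prior_guidance_scale=4.0,
    generator=generator,
).images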
Optimisation Technique 1: Flash Attention
Starting from version 2.0, PyTorch has integrated a highly optimised and resource-friendly version of the attention mechanism called torch.nn.functional.scaled_dot_product_attention or SDPA. Depending on the nature of the input, this function taps into multiple underlying optimisations. Its performance and memory efficiency outshine the traditional attention mechanism. Remarkably, the SDPA function mirrors the characteristics of the flash attention technique, as highlighted in the research paper Fast and Memory-Efficient Exact Attention with IO-Awareness by Dao and team.
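To give a sense of what this function does, here is a tiny, self-contained sketch with arbitrary tensor shapes (purely illustrative, not taken from Würstchen itself):

import torch
import torch.nn.functional as F

# Arbitrary (batch, heads, sequence_length, head_dim) tensors, just for illustration
query = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# PyTorch dispatches to a flash / memory-efficient kernel when one is available
output = F.scaled_dot_product_attention(query, key, value)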
If you're using Diffusers with PyTorch 2.0 or a later version and the SDPA function is available, these enhancements are applied automatically. Get started by installing torch 2.0 or a newer version using the official guidelines!
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
For an in-depth look at how diffusers leverages SDPA, check out the documentation.
If you're on a version of PyTorch earlier than 2.0, you can still achieve memory-efficient attention using the xFormers library:
pipeline.enable_xformers_memory_efficient_attention()
Optimisation Technique 2: Torch Compile
If you're on the hunt for an extra performance boost, you can make use of torch.compile. You should apply it to both the prior's and decoder's main model for the biggest increase in performance.
pipeline.prior_prior = torch.compile(pipeline.prior_prior, mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
Bear in mind that the initial inference step will take a long time (up to 2 minutes) while the models are being compiled. After that you can just run inference normally:
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
And the good news is that this compilation is a one-time execution. After that, you are set to experience consistently faster inference for the same image resolutions. The initial time investment in compilation is quickly offset by the subsequent speed advantages. For a deeper dive into torch.compile and its nuances, check out the official documentation.
How was the model trained?
The ability to train this model was only possible through compute resources provided by Stability AI.
We want to say a special thank you to Stability for giving us the possibility to pursue this kind of research, with the prospect
of making it accessible to so many more people!
Resources
- Further details about this model can be found in the official diffusers documentation.
- All of the checkpoints can be found on the hub
- You can try out the demo here.
- Join our Discord if you want to discuss future projects or even contribute with your own ideas!
- Training code and more can be found in the official GitHub repository





