StableDiffusion: Decode latents separately to run larger batches #1150
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
|
Hey @kig, I don't think we should make this the default as it necessarily makes the execution slower - however it might make a lot of sense to add an opt-in method for it.
|
E.g., could you share a code snippet that works with your PR on an 8GB GPU but not with the current pipeline implementation?
|
Yeah, I agree on making this opt-in rather than the default. Here's the test script:

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                               revision="fp16",
                                               torch_dtype=torch.float16)
pipe.enable_attention_slicing()
# Disable safety_checker for testing, it's triggered by noise.
pipe.safety_checker = None
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

for samples in [1, 4, 8, 16, 32]:
    print(f"Generating {samples} image{'s' if samples > 1 else ''}")
    images = pipe([prompt] * samples, num_inference_steps=1).images
    if len(images) != samples:
        raise RuntimeError(f"Expected {samples} images, got {len(images)}")
```

I added some simple time.time() profiling around the VAE decode too, from just before the decode call to just after it.
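For reference, a minimal sketch of what that profiling looks like (my own illustration; the `timed_vae_decode` helper and its placement are assumptions, the actual measurement was done inline in the pipeline):

```python
import time
import torch

@torch.no_grad()
def timed_vae_decode(vae, latents):
    # Hypothetical helper for illustration: time the full-batch VAE decode.
    t0 = time.time()
    image = vae.decode(latents).sample
    print("VAE decode elapsed", time.time() - t0)
    return image
```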
With the current pipeline (full-batch VAE decode) on an 8GB GPU:

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00, 1.70s/it]
VAE decode elapsed 0.10824728012084961
Generating 4 images
100%|█| 1/1 [00:00<00:00, 2.49it/s]
VAE decode elapsed 0.905648946762085
Generating 8 images
100%|█| 1/1 [00:01<00:00, 1.33s/it]
VAE decode elapsed 2.5215229988098145
Generating 16 images
100%|█| 1/1 [00:02<00:00, 2.72s/it]
Traceback (most recent call last):
  File "test_samples.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
  ...
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 8.00 GiB total capacity; 5.53 GiB already allocated; 0 bytes free; 6.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
With this PR (VAE decode one image at a time) on the same 8GB GPU:

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00, 1.69s/it]
VAE decode elapsed 0.11412405967712402
Generating 4 images
100%|█| 1/1 [00:00<00:00, 2.51it/s]
VAE decode elapsed 0.4106144905090332
Generating 8 images
100%|█| 1/1 [00:00<00:00, 1.29it/s]
VAE decode elapsed 0.8257348537445068
Generating 16 images
100%|█| 1/1 [00:01<00:00, 1.52s/it]
VAE decode elapsed 1.6556949615478516
Generating 32 images
100%|█| 1/1 [00:08<00:00, 8.05s/it]
VAE decode elapsed 3.727773666381836
```

These scale more or less linearly. The 32-image run is starting to hit something, though: I saw its decode time fluctuate between 3.5 and 8 seconds. Huh, that was not what I was expecting. I thought decoding the full batch at once would have some small efficiency benefit from avoiding setup work, but it looks like there's something else in play.
|
Testing on a 24GB card, the VAE decode time scales linearly, but it runs into an issue with 32 samples. This is with the "full batch at a time" approach:

```
# python test.py
Generating 1 image
100%|█| 1/1 [00:00<00:00, 3.21it/s]
VAE decode elapsed 0.08451700210571289
Generating 4 images
100%|█| 1/1 [00:00<00:00, 3.14it/s]
VAE decode elapsed 0.32993221282958984
Generating 8 images
100%|█| 1/1 [00:00<00:00, 2.01it/s]
VAE decode elapsed 0.6512401103973389
Generating 16 images
100%|█| 1/1 [00:00<00:00, 1.00it/s]
VAE decode elapsed 1.289898157119751
Generating 32 images
100%|█| 1/1 [00:01<00:00, 1.94s/it]
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/app/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 403, in __call__
    image = self.vae.decode(latents).sample
  File "/app/diffusers/src/diffusers/models/vae.py", line 581, in decode
    dec = self.decoder(z)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/vae.py", line 217, in forward
    sample = up_block(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1322, in forward
    hidden_states = upsampler(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/resnet.py", line 58, in forward
    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3918, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
```

[edit] Here, doing the VAE decode one image at a time seems to be 15% faster at batch size 8.
|
Doing the VAE decode one image at a time seems to be 15% faster at batch size 8 and 2% slower at batch size 1. It's a small enough difference that it might be noise. The time difference is about 10 ms per image, which doesn't move the needle much when an 8-image batch takes 8 seconds to generate. However, you can go a step further and do a tiled VAE decode coupled with xformers to render 8k images on a 24GB GPU (rough sketch of the idea below). Prompt: "a beautiful landscape photograph, 8k", 150 seconds per iteration. 3840x2160 goes 15x faster, so there's a perf cliff somewhere in between.
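The tiled decode isn't part of this PR; here's a rough sketch of the idea, with a hypothetical `tiled_vae_decode` helper made up for illustration. It splits the latent spatially, decodes each tile separately, and stitches the results back together; real tiled decoders also blend overlapping tiles to hide seams, which this sketch skips.

```python
import torch

@torch.no_grad()
def tiled_vae_decode(vae, latents, tile_size=64):
    # latents: (B, 4, H/8, W/8); the SD VAE upsamples by 8x, so a 64x64 latent
    # tile decodes to a 512x512 image tile.
    _, _, h, w = latents.shape
    rows = []
    for y in range(0, h, tile_size):
        row = []
        for x in range(0, w, tile_size):
            tile = latents[:, :, y:y + tile_size, x:x + tile_size]
            row.append(vae.decode(tile).sample)
        rows.append(torch.cat(row, dim=3))  # stitch tiles along width
    return torch.cat(rows, dim=2)           # stitch rows along height
```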
|
btw, I have a speed boost for the decoder here: eliminates a |
|
|
```diff
  latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)])
```
Could we please add an enable_vae_slicing() for this? I don't think all the backends, such as MPS, like the "for loop". Happy to add this feature with an enable_vae_slicing().
This list comprehension iterates over a regular Python sequence of tensors rather than indexing into a tensor, so I don't think MPS would have trouble with it. A similar pattern (split a tensor, then cat the splits) is working for me on MPS:
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/reference/multihead_attention.py#L116
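For example (my own minimal illustration, not from the linked code), a split-then-cat loop like this walks a plain Python tuple of tensors and only uses regular tensor ops, so it should behave the same on any backend:

```python
import torch

x = torch.arange(8.0).reshape(4, 2)
# split(1) returns a tuple of 1-row tensors; the list comprehension walks that tuple.
parts = [t * 2 for t in x.split(1)]
y = torch.cat(parts)
assert torch.equal(y, x * 2)
```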
Thanks! Added enable_vae_slicing() in the latest commit.
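With that, opting in from user code looks something like this (a sketch based on the API discussed in this PR; the model ID and batch size are just examples):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                               revision="fp16",
                                               torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()  # decode the VAE one image at a time

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe([prompt] * 8).images
```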
|
Sorry for the late reply, it's been hectic. I added the sliced-decode switch on the VAE side. Let me know if you prefer it in the pipeline and I can move it there.
Thanks! I tried making the VAE use xformers attention and that did help with memory use. But the Resnet convolution layers in src/diffusers/models/resnet.py |
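For context, pipeline-level xformers attention can be enabled like this (assuming the xformers package is installed; whether it reaches the VAE's attention blocks depends on the diffusers version):

```python
# Assumes `pipe` is a StableDiffusionPipeline and xformers is installed.
pipe.enable_xformers_memory_efficient_attention()
```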
|
Hey @kig, Great, the API looks very nice to me now :-) Could we do two last things:
Maybe adding two links that might help:
Let me know if you need more pointers :-) |
|
@patrickvonplaten here we go, I added tests and docs. Let me know how they look. |
|
Hey @kig, Awesome job :-) Merging this PR! |
…gingface#1150)
* StableDiffusion: Decode latents separately to run larger batches
* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode
* Rename sliced_decode to slicing
* fix whitespace
* fix quality check and repository consistency
* VAE slicing tests and documentation
* API doc hooks for VAE slicing
* reformat vae slicing tests
* Skip VAE slicing for one-image batches
* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen <[email protected]>

You can use larger batch sizes if you do the VAE decode one image at a time. Currently the VAE decode runs on the full batch, which limits both the batch size and the maximum throughput on an 8GB GPU.
This PR makes the VAE decode run one image at a time, which makes 20-image (512x512) batches possible on 8GB GPUs with fp16.
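A minimal sketch of the approach (an illustration of the idea, not the exact PR diff): decode each latent separately and concatenate the results, so peak decoder memory no longer grows with batch size.

```python
import torch

@torch.no_grad()
def decode_latents_sliced(vae, latents):
    # Decode one latent at a time; peak VAE memory stays at the single-image level.
    images = [vae.decode(latent).sample for latent in latents.split(1)]
    return torch.cat(images)
```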