Conversation

@kig (Contributor) commented Nov 5, 2022

You can use larger batch sizes if you run the VAE decode one image at a time. Currently the VAE decode runs on the full batch, which limits the achievable batch size and maximum throughput on 8GB GPUs.

This PR makes the VAE decode run one image at a time, which makes 20-image (512x512) batches possible on 8GB GPUs with fp16.
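
In essence, the change replaces the single full-batch decode with a per-image decode plus a concat (simplified):

# before: decode the whole batch at once
image = self.vae.decode(latents).sample
# after: decode one latent at a time and concatenate the results
image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)])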


@patrickvonplaten (Contributor)

Hey @kig,

I don't think we should make this the default, as it necessarily makes execution slower. However, it might make a lot of sense to add an enable_vae_slicing function for this. Before doing so, could you maybe provide a code snippet that can be run to see the memory savings?

@patrickvonplaten (Contributor)

E.g. a code snippet that works with your PR on an 8GB GPU but not with the current pipeline implementation?

@kig (Contributor, Author) commented Nov 8, 2022

Hi @patrickvonplaten!

Yeah, I agree on the enable_vae_slicing approach. Here's a small snippet to test:

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe.enable_attention_slicing()
# Disable the safety_checker for testing; the noisy 1-step outputs tend to trigger it.
pipe.safety_checker = None
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

for samples in [1, 4, 8, 16, 32]:
    print(f"Generating {samples} image{'s' if samples > 1 else ''}")
    images = pipe([prompt] * samples, num_inference_steps=1).images
    if len(images) != samples:
        raise RuntimeError(f"Expected {samples} images, got {len(images)}")

I added some simple time.time() profiling around the VAE decode too, from before the decode to after the .cpu() call.
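
For reference, the timing looked roughly like this inside the pipeline's __call__ (a sketch, not the exact profiling code):

import time

t0 = time.time()
image = self.vae.decode(latents).sample
image = image.cpu()
print(f"VAE decode elapsed {time.time() - t0}")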

Results with image = self.vae.decode(latents).sample (full-batch decode):

$ python test_samples.py
Generating 1 image
100%|| 1/1 [00:01<00:00,  1.70s/it]
VAE decode elapsed 0.10824728012084961
Generating 4 images
100%|| 1/1 [00:00<00:00,  2.49it/s]
VAE decode elapsed 0.905648946762085
Generating 8 images
100%|| 1/1 [00:01<00:00,  1.33s/it]
VAE decode elapsed 2.5215229988098145
Generating 16 images
100%|| 1/1 [00:02<00:00,  2.72s/it]
Traceback (most recent call last):
  File "test_samples.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
...
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 8.00 GiB total capacity; 5.53 GiB already allocated; 0 bytes free; 6.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The weird thing is that the full-batch VAE decode time seems to scale superlinearly: double the batch size and the runtime triples. [edit] I couldn't reproduce this in a Docker image; there the decode time scaled linearly until memory use got close to 100%, then there was a sudden 10x-50x jump in decode time followed by an OOM at batch size 12.

Results with image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) (per-image decode):

$ python test_samples.py
Generating 1 image
100%|| 1/1 [00:01<00:00,  1.69s/it]
VAE decode elapsed 0.11412405967712402
Generating 4 images
100%|| 1/1 [00:00<00:00,  2.51it/s]
VAE decode elapsed 0.4106144905090332
Generating 8 images
100%|| 1/1 [00:00<00:00,  1.29it/s]
VAE decode elapsed 0.8257348537445068
Generating 16 images
100%|| 1/1 [00:01<00:00,  1.52s/it]
VAE decode elapsed 1.6556949615478516
Generating 32 images
100%|| 1/1 [00:08<00:00,  8.05s/it]
VAE decode elapsed 3.727773666381836

These scale more or less linearly. The 32-image batch is starting to hit something, though: that decode time fluctuated between 3.5 and 8 seconds.

Huh, that was not what I was expecting. I thought decoding the full batch at once would have a small efficiency benefit from avoiding per-call setup work, but it looks like something else is at play.

@kig (Contributor, Author) commented Nov 8, 2022

Testing on a 24GB card with the full-batch-at-a-time approach, the VAE decode time scales linearly, but it runs into an issue at 32 samples.

# python test.py
Generating 1 image
100%|| 1/1 [00:00<00:00,  3.21it/s]
VAE decode elapsed 0.08451700210571289
Generating 4 images
100%|| 1/1 [00:00<00:00,  3.14it/s]
VAE decode elapsed 0.32993221282958984
Generating 8 images
100%|| 1/1 [00:00<00:00,  2.01it/s]
VAE decode elapsed 0.6512401103973389
Generating 16 images
100%|| 1/1 [00:00<00:00,  1.00it/s]
VAE decode elapsed 1.289898157119751
Generating 32 images
100%|| 1/1 [00:01<00:00,  1.94s/it]
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/app/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 403, in __call__
    image = self.vae.decode(latents).sample
  File "/app/diffusers/src/diffusers/models/vae.py", line 581, in decode
    dec = self.decoder(z)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/vae.py", line 217, in forward
    sample = up_block(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1322, in forward
    hidden_states = upsampler(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/resnet.py", line 58, in forward
    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3918, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements

[edit] Here, doing the VAE decode one image at a time seems to be 15% faster at batch size 8.

@kig (Contributor, Author) commented Nov 9, 2022

Doing the VAE decode one image at a time seems to be 15% faster at batch size 8 and 2% slower at batch size 1. The difference is small enough that it might be noise; it works out to about 10 ms per image, which doesn't move the needle much when an 8-image batch takes 8 seconds to generate.

However, you can go a step further and do tiled VAE decode coupled with xformers to render 8k images on a 24GB GPU...

Prompt: "a beautiful landscape photograph, 8k", 150 seconds per iteration.

[image: 7680x4320 render]

3840x2160 renders 15x faster, so there's a performance cliff somewhere in between.
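
For the curious, a rough sketch of what a spatially tiled decode can look like (this is not what's in this PR; naive tiling with no overlap, so seams between tiles are possible, and the tile size is arbitrary):

import torch

@torch.no_grad()
def tiled_vae_decode(vae, latents, tile_size=64):
    # Decode the latent tensor in spatial tiles to bound peak memory.
    # Each latent tile decodes to an image tile; tiles are stitched back together.
    _, _, h, w = latents.shape
    rows = []
    for y in range(0, h, tile_size):
        row = []
        for x in range(0, w, tile_size):
            tile = latents[:, :, y:y + tile_size, x:x + tile_size]
            row.append(vae.decode(tile).sample)
        rows.append(torch.cat(row, dim=3))  # stitch along width
    return torch.cat(rows, dim=2)           # stitch along height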

@Birch-san (Contributor)

By the way, I have a speed boost for the decoder here: #1203

It eliminates a sqrt() and a multiply, simplifies the 4D tensor to 3D (num_heads is always 1), and uses batched matmul.
yes, it'd be possible to implement sliced attention too, if that's a limiting factor.
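
Roughly the shape of that change (a sketch of single-head attention with the scale folded into a batched matmul; not the exact code in #1203):

import torch

def single_head_attention(q, k, v, scale):
    # q, k, v: (batch, seq_len, channels). With num_heads == 1 there's no need
    # to reshape to 4D; baddbmm folds the 1/sqrt(d) scale into the matmul itself.
    scores = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,      # the empty tensor is ignored when beta == 0
        alpha=scale,
    )
    return torch.bmm(scores.softmax(dim=-1), v)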


Review comments on the changed decode lines in pipeline_stable_diffusion.py:

latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)])
Contributor:

Could we please add an enable_vae_slicing() for this? I don't think all backends, such as MPS, like the for loop. Happy to add this feature with an enable_vae_slicing.

Contributor:

This list comprehension iterates over a regular Python sequence of tensors (the result of .split(1)), not indexing into a tensor, so I don't think MPS would have trouble with it. A similar pattern (split a tensor, then cat the splits) works for me on MPS:
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/reference/multihead_attention.py#L116

Contributor Author:

Thanks! Added enable_vae_slicing() in the latest commit.

@kig (Contributor, Author) commented Nov 18, 2022

Sorry for the late reply, it's been hectic. I added an enable_vae_slicing() function to StableDiffusionPipeline and moved the slicing implementation to AutoencoderKL.

Let me know if you prefer it in the pipeline and I can move it there.
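
From the user's side it looks roughly like this (a sketch; the pipeline method just toggles slicing on the underlying vae):

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_slicing()   # decode latents one image at a time
images = pipe(["a photo of an astronaut riding a horse on mars"] * 8).images
pipe.disable_vae_slicing()  # back to full-batch decode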

@kig (Contributor, Author) commented Nov 18, 2022

yes, it'd be possible to implement sliced attention too, if that's a limiting factor.

Thanks! I tried making the VAE use xformers attention, and that did help with memory use. But the ResNet convolution layers in src/diffusers/models/resnet.py (Upsample2D and ResnetBlock2D) turned out to be another limitation.

@patrickvonplaten (Contributor)

Hey @kig,

Great, the API looks very nice to me now :-)

Could we do two last things:

  • 1.) Add docs
  • 2.) Add tests

Maybe adding two links that might help:

Let me know if you need more pointers :-)

@kig (Contributor, Author) commented Nov 23, 2022

@patrickvonplaten here we go, I added tests and docs. Let me know how they look.

@patrickvonplaten (Contributor)

Hey @kig,

Awesome job :-) Merging this PR!

@patrickvonplaten merged commit c28d3c8 into huggingface:main on Nov 29, 2022
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
…gingface#1150)

* StableDiffusion: Decode latents separately to run larger batches

* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode

* Rename sliced_decode to slicing

* fix whitespace

* fix quality check and repository consistency

* VAE slicing tests and documentation

* API doc hooks for VAE slicing

* reformat vae slicing tests

* Skip VAE slicing for one-image batches

* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen <[email protected]>