
Commit c28d3c8

Authored by kig (Ilmari Heikkinen) and a co-author
StableDiffusion: Decode latents separately to run larger batches (#1150)
* StableDiffusion: Decode latents separately to run larger batches
* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode
* Rename sliced_decode to slicing
* fix whitespace
* fix quality check and repository consistency
* VAE slicing tests and documentation
* API doc hooks for VAE slicing
* reformat vae slicing tests
* Skip VAE slicing for one-image batches
* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen <[email protected]>
1 parent bcb6cc1 commit c28d3c8
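
In short, this commit adds `enable_vae_slicing()` / `disable_vae_slicing()` to the pipelines, backed by `enable_slicing()` / `disable_slicing()` on the VAE. A minimal, hedged usage sketch (not part of the diff itself), reusing the checkpoint and prompt from the documentation change below:

```Python
# Minimal sketch of the new API: decode a large batch with the VAE running
# one image at a time to keep peak VRAM low.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

pipe.enable_vae_slicing()  # equivalent to pipe.vae.enable_slicing()
images = pipe(["a photo of an astronaut riding a horse on mars"] * 32).images
```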

File tree: 6 files changed (+171, -1 lines)


docs/source/api/pipelines/stable_diffusion.mdx

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
  - __call__
  - enable_attention_slicing
  - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing

## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline

docs/source/optimization/fp16.mdx

Lines changed: 28 additions & 0 deletions
@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]

There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!

+
+## Sliced VAE decode for larger batches
+
+To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
+
+You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+
+To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    revision="fp16",
+    torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_vae_slicing()
+images = pipe([prompt] * 32).images
+```
+
+You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
+
+
## Offloading to CPU with accelerate for memory savings

For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
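
The documentation above recommends pairing sliced VAE decode with attention slicing; the sketch below combines the two on the same checkpoint as the example in the diff (the batch size is illustrative):

```Python
# Hedged sketch: attention slicing reduces UNet memory during denoising, while
# VAE slicing reduces decoder memory when turning latents into images.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

pipe.enable_attention_slicing()  # slice attention inside the UNet
pipe.enable_vae_slicing()        # decode the final latents one image at a time

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe([prompt] * 32).images
```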

src/diffusers/models/vae.py

Lines changed: 30 additions & 1 deletion
@@ -565,6 +565,7 @@ def __init__(

        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.use_slicing = False

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)

@@ -576,7 +577,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:

        return AutoencoderKLOutput(latent_dist=posterior)

-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)

@@ -585,6 +586,34 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:

        return DecoderOutput(sample=dec)

+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
    def forward(
        self,
        sample: torch.FloatTensor,
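
For intuition, here is a small standalone sketch of the `z.split(1)` / `torch.cat` pattern the new `decode()` uses. The `nn.Conv2d` "decoder" is a stand-in for illustration only, not the diffusers `AutoencoderKL`:

```Python
# Hedged sketch of the batch-splitting pattern in the new decode().
import torch

decoder = torch.nn.Conv2d(4, 3, kernel_size=1)  # toy stand-in for the VAE decoder
z = torch.randn(8, 4, 64, 64)  # a batch of 8 latents

# Full-batch decode: peak activation memory scales with the whole batch.
full = decoder(z)

# Sliced decode: each single-image slice is decoded on its own, then the
# results are concatenated back along the batch dimension.
sliced = torch.cat([decoder(z_slice) for z_slice in z.split(1)])

# The two paths agree up to floating-point rounding, which is why the tests
# below allow a small tolerance (3e-3) rather than exact equality.
assert torch.allclose(full, sliced, atol=1e-5)
```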

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -216,6 +216,22 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -215,6 +215,22 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
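
Both pipeline diffs above are identical: the new methods simply forward to the VAE. A hedged sketch of that delegation (the checkpoint id is illustrative; toggling the flag does not require a GPU):

```Python
# Hedged sketch: enable_vae_slicing()/disable_vae_slicing() on the pipeline
# call vae.enable_slicing()/vae.disable_slicing(), which flip use_slicing.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

pipe.enable_vae_slicing()
assert pipe.vae.use_slicing is True

pipe.disable_vae_slicing()
assert pipe.vae.use_slicing is False  # back to decoding the whole batch at once
```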

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 79 additions & 0 deletions
@@ -557,6 +557,46 @@ def test_stable_diffusion_attention_chunk(self):

        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4

+    def test_stable_diffusion_vae_slicing(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        image_count = 4
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_1 = sd_pipe(
+            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+        )
+
+        # make sure sliced vae decode yields the same result
+        sd_pipe.enable_vae_slicing()
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_2 = sd_pipe(
+            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+        )
+
+        # there is a small discrepancy at image borders vs. full batch decode
+        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
+
    def test_stable_diffusion_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet

@@ -886,6 +926,45 @@ def test_stable_diffusion_memory_chunking(self):
        assert mem_bytes > 3.75 * 10**9
        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

+    def test_stable_diffusion_vae_slicing(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # enable vae slicing
+        pipe.enable_vae_slicing()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image_chunked = output_chunked.images
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 4 GB is allocated
+        assert mem_bytes < 4e9
+
+        # disable vae slicing
+        pipe.disable_vae_slicing()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image = output.images
+
+        # make sure that more than 4 GB is allocated
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes > 4e9
+        # There is a small discrepancy at the image borders vs. a fully batched version.
+        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3
+
    def test_stable_diffusion_text2img_pipeline_fp16(self):
        torch.cuda.reset_peak_memory_stats()
        model_id = "CompVis/stable-diffusion-v1-4"
