4 changes: 3 additions & 1 deletion docs/source/en/api/pipelines/stable_diffusion/text2img.mdx
@@ -36,4 +36,6 @@ Available Checkpoints are:
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- enable_vae_tiling
- disable_vae_tiling
28 changes: 28 additions & 0 deletions docs/source/en/optimization/fp16.mdx
@@ -133,6 +133,34 @@ images = pipe([prompt] * 32).images
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.


## Tiled VAE decode and encode for large images

Tiled VAE processing makes it possible to work with large images on limited VRAM, for example generating 4k images in 8GB of VRAM. The tiled VAE decoder splits the image into overlapping tiles, decodes each tile, and blends the outputs to compose the final image.

You will want to pair this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further reduce memory use.

To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:

```python
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
pipe.enable_xformers_memory_efficient_attention()

image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
```

The output image will have some tile-to-tile tone variation because each tile is decoded separately, but you shouldn't see sharp seams between the tiles. Tiling is turned off for images that are 512x512 or smaller.
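
For instance, here is a minimal sketch of the threshold behavior (the sizes below are illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_vae_tiling()

prompt = "a beautiful landscape photograph"

# 512x512 fits within a single tile, so the decode runs in one pass
small = pipe(prompt, width=512, height=512).images[0]

# 1024x1024 exceeds the tile size, so the VAE decodes in overlapping tiles
large = pipe(prompt, width=1024, height=1024).images[0]
```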


<a name="sequential_offloading"></a>
## Offloading to CPU with accelerate for memory savings

165 changes: 149 additions & 16 deletions src/diffusers/models/autoencoder_kl.py
@@ -107,10 +107,54 @@ def __init__(

self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

self.use_slicing = False
self.use_tiling = False

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
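# e.g. for a typical SD VAE (sample_size 512, four down blocks) the latent tile
# is 512 / 2**3 = 64 latents per side, and adjacent tiles overlap by 25%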

def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
the processing of larger images.
"""
self.use_tiling = use_tiling

def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.enable_tiling(False)

def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)

h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -121,6 +165,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
return AutoencoderKLOutput(latent_dist=posterior)

def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)

z = self.post_quant_conv(z)
dec = self.decoder(z)

@@ -129,22 +176,6 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:

return DecoderOutput(sample=dec)

def enable_slicing(self):
r"""
Enable sliced VAE decoding.

When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:

return DecoderOutput(sample=decoded)

def blend_v(self, a, b, blend_extent):
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b

def blend_h(self, a, b, blend_extent):
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
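
# Both helpers apply a linear crossfade over `blend_extent` rows/columns: at
# offset 0 the output comes entirely from `a`'s trailing edge, and the weight
# shifts linearly toward `b` as the offset approaches `blend_extent`.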

def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
Member: Just a question, what use case does tiled encoding fulfill? Also, do we need blending during the encoding phase? If we used this as an autoencoder of a large image, I would have thought that blending during decoding would be enough to avoid seams between the tiles.

Contributor: I guess it's just to save memory? Encoding a large image is pretty memory intensive, no?

Member: Yeah, I meant that we don't need encoding for inference, and I don't think we'll train with very large images. It was just out of curiosity.

Contributor (Author): I did need it for something. IIRC img2img with large images.

Member: Oh, you did img2img on huge images; cool, understood.
r"""Encode a batch of images using a tiled encoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
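# e.g. with tile_sample_min_size = 512 and tile_overlap_factor = 0.25:
# overlap_size = 384 (pixel stride between tiles), blend_extent = 16 latents,
# and row_limit = 48 latent rows/columns are kept from each tile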

# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))

moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)

if not return_dict:
return (posterior,)

return AutoencoderKLOutput(latent_dist=posterior)
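
A minimal usage sketch (the checkpoint and sizes are illustrative) of when `tiled_encode` engages: `encode` routes any input larger than the sample tile through it.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_tiling()

# 1024x1024 pixels exceeds the 512-pixel sample tile, so encode() uses tiled_encode
x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample()  # shape (1, 4, 128, 128)
```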

def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""Decode a batch of images using a tiled decoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
`True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
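# e.g. with tile_latent_min_size = 64 and tile_overlap_factor = 0.25:
# overlap_size = 48 (latent stride between tiles), blend_extent = 128 pixels,
# and row_limit = 384 pixel rows/columns are kept from each tile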

# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))

dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)
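
A matching sketch (again with illustrative sizes) of the decode path:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_tiling()

# a 128x128 latent exceeds the 64-latent tile, so decode() uses tiled_decode
z = torch.randn(1, 4, 128, 128)
with torch.no_grad():
    image = vae.decode(z).sample  # shape (1, 3, 1024, 1024)
```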

def forward(
self,
sample: torch.FloatTensor,
16 changes: 16 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -183,6 +183,22 @@ def disable_vae_slicing(self):
"""
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding.

When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
"""
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
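
A brief sketch (the checkpoint name is an assumption) of toggling this per generation on the pipeline:

```python
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained(
    "BAAI/AltDiffusion-m9", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_tiling()
large = pipe("a watercolor harbor at dawn", width=1536, height=1536).images[0]

pipe.disable_vae_tiling()  # back to single-pass decode
small = pipe("a watercolor harbor at dawn").images[0]
```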

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -186,6 +186,22 @@ def disable_vae_slicing(self):
"""
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding.

When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
"""
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
1 change: 1 addition & 0 deletions tests/pipelines/audio_diffusion/test_audio_diffusion.py
@@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self):
)
return vqvae, unet

@slow
Member: Is this something we need?

Contributor: Yes, this test takes a minute and this model has pretty much 0 usage, so disabling it for fast runs.

Member: But the decorator is on test_audio_diffusion, seemed unrelated :)
def test_audio_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
mel = Mel()
75 changes: 75 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -419,6 +419,29 @@ def test_stable_diffusion_vae_slicing(self):
# there is a small discrepancy at image borders vs. full batch decode
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3

def test_stable_diffusion_vae_tiling(self):
Contributor: nice test!

device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

# disable the safety checker so the numerical comparison below is not affected
components["safety_checker"] = None
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"

# Test that tiled decode at 512x512 yields the same result as the non-tiled decode
generator = torch.Generator(device=device).manual_seed(0)
output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

# make sure tiled vae decode yields the same result
sd_pipe.enable_vae_tiling()
generator = torch.Generator(device=device).manual_seed(0)
output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1

def test_stable_diffusion_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -699,6 +722,58 @@ def test_stable_diffusion_vae_slicing(self):
# There is a small discrepancy at the image borders vs. a fully batched version.
assert np.abs(image_sliced - image).max() < 1e-2

def test_stable_diffusion_vae_tiling(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)

prompt = "a photograph of an astronaut riding a horse"

# enable vae tiling
pipe.enable_vae_tiling()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt],
width=640,
height=640,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
)
image_chunked = output_chunked.images

mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 4 GB is allocated
assert mem_bytes < 4e9

# disable vae tiling
pipe.disable_vae_tiling()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt],
width=640,
height=640,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
)
image = output.images

# make sure that more than 4 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 4e9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

def test_stable_diffusion_fp16_vs_autocast(self):
# this test makes sure that the original model with autocast
# and the new model with fp16 yield the same result