From 49b61c8e5bc5f02a655f1af0a245c573da02af1f Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 27 Nov 2022 19:02:24 +0800 Subject: [PATCH 01/19] Tiled VAE for high-res text2img and img2img --- src/diffusers/models/vae.py | 109 ++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 30de343d08ee..a61e52b8b248 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -566,7 +566,18 @@ def __init__( self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_tiling = False + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling: + return self.tiled_encode(x, return_dict=return_dict) + h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) @@ -577,6 +588,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK return AutoencoderKLOutput(latent_dist=posterior) def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling: + return self.tiled_decode(z, return_dict=return_dict) + z = self.post_quant_conv(z) dec = self.decoder(z) @@ -585,6 +599,101 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=dec) + def blend_v(self, a, b, blend_width): + for y in range(blend_width): + b[:, :, y, :] = a[:, :, -blend_width+y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width) + return b + + def blend_h(self, a, b, blend_width): + for x in range(blend_width): + b[:, :, :, x] = a[:, :, :, -blend_width+x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + The end result of tiled encoding is different from non-tiled encoding due to each tile using a different encoder. + To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. + You may still see tile-sized changes in the look of the output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. + """ + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], 384): + row = [] + for j in range(0, x.shape[3], 384): + tile = x[:, :, i:i+512, j:j+512] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i,row in enumerate(rows): + result_row = [] + for j,tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i-1][j], tile, 16) + if j > 0: + tile = self.blend_h(row[j-1], tile, 16) + result_row.append(tile[:, :, :48, :48]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r"""Decode a batch of images using a tiled decoder. + + The end result of tiled decoding is different from non-tiled decoding due to each tile using a different decoder. + To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. + You may still see tile-sized changes in the look of the output, but they should be much less noticeable. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], 48): + row = [] + for j in range(0, z.shape[3], 48): + tile = z[:, :, i:i+64, j:j+64] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i,row in enumerate(rows): + result_row = [] + for j,tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i-1][j], tile, 128) + if j > 0: + tile = self.blend_h(row[j-1], tile, 128) + result_row.append(tile[:, :, :384, :384]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + def forward( self, sample: torch.FloatTensor, From 63d5661b47846350532fb23901efb718203b9b70 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Mon, 28 Nov 2022 10:56:16 +0800 Subject: [PATCH 02/19] vae tiling, fix formatting --- src/diffusers/models/vae.py | 38 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index a61e52b8b248..efedd083d0bf 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -601,20 +601,20 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode def blend_v(self, a, b, blend_width): for y in range(blend_width): - b[:, :, y, :] = a[:, :, -blend_width+y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width) + b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width) return b def blend_h(self, a, b, blend_width): for x in range(blend_width): - b[:, :, :, x] = a[:, :, :, -blend_width+x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width) + b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width) return b def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - r"""Encode a batch of images using a tiled encoder. + r"""Encode a batch of images using a tiled encoder. - The end result of tiled encoding is different from non-tiled encoding due to each tile using a different encoder. - To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. - You may still see tile-sized changes in the look of the output, but they should be much less noticeable. + The end result of tiled encoding is different from non-tiled encoding due to each tile using a different + encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may + still see tile-sized changes in the look of the output, but they should be much less noticeable. Args: x (`torch.FloatTensor`): Input batch of images. @@ -626,21 +626,21 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen for i in range(0, x.shape[2], 384): row = [] for j in range(0, x.shape[3], 384): - tile = x[:, :, i:i+512, j:j+512] + tile = x[:, :, i : i + 512, j : j + 512] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) rows.append(row) result_rows = [] - for i,row in enumerate(rows): + for i, row in enumerate(rows): result_row = [] - for j,tile in enumerate(row): + for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i-1][j], tile, 16) + tile = self.blend_v(rows[i - 1][j], tile, 16) if j > 0: - tile = self.blend_h(row[j-1], tile, 16) + tile = self.blend_h(row[j - 1], tile, 16) result_row.append(tile[:, :, :48, :48]) result_rows.append(torch.cat(result_row, dim=3)) @@ -655,9 +655,9 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r"""Decode a batch of images using a tiled decoder. - The end result of tiled decoding is different from non-tiled decoding due to each tile using a different decoder. - To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. - You may still see tile-sized changes in the look of the output, but they should be much less noticeable. + The end result of tiled decoding is different from non-tiled decoding due to each tile using a different + decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may + still see tile-sized changes in the look of the output, but they should be much less noticeable. Args: z (`torch.FloatTensor`): Input batch of latent vectors. @@ -670,21 +670,21 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ for i in range(0, z.shape[2], 48): row = [] for j in range(0, z.shape[3], 48): - tile = z[:, :, i:i+64, j:j+64] + tile = z[:, :, i : i + 64, j : j + 64] tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) rows.append(row) result_rows = [] - for i,row in enumerate(rows): + for i, row in enumerate(rows): result_row = [] - for j,tile in enumerate(row): + for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i-1][j], tile, 128) + tile = self.blend_v(rows[i - 1][j], tile, 128) if j > 0: - tile = self.blend_h(row[j-1], tile, 128) + tile = self.blend_h(row[j - 1], tile, 128) result_row.append(tile[:, :, :384, :384]) result_rows.append(torch.cat(result_row, dim=3)) From 4b6536d49a6237c04e9a0c7ebe0ee2de66bfc392 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 4 Dec 2022 02:19:11 +0800 Subject: [PATCH 03/19] enable_vae_tiling API and tests --- src/diffusers/models/vae.py | 36 ++++-- .../alt_diffusion/pipeline_alt_diffusion.py | 16 +++ .../pipeline_stable_diffusion.py | 16 +++ .../stable_diffusion/test_stable_diffusion.py | 115 ++++++++++++++++++ 4 files changed, 170 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 9aa0f3f0a300..be1581154648 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -569,24 +569,21 @@ def __init__( self.use_tiling = False def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ self.use_tiling = use_tiling def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ self.enable_tiling(False) - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - if self.use_tiling: - return self.tiled_encode(x, return_dict=return_dict) - - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - def enable_slicing(self): r""" Enable sliced VAE decoding. @@ -603,6 +600,19 @@ def disable_slicing(self): """ self.use_slicing = False + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling: + return self.tiled_encode(x, return_dict=return_dict) + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: if self.use_tiling: return self.tiled_decode(z, return_dict=return_dict) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index fb64a34a0bd8..1c748b01c3da 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -214,6 +214,22 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a3a8703f3ea4..5162aebad732 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -213,6 +213,22 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 8dce61c3a456..48c06ad5819e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -597,6 +597,68 @@ def test_stable_diffusion_vae_slicing(self): # there is a small discrepancy at image borders vs. full batch decode assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3 + def test_stable_diffusion_vae_tiling(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + # Test that tiled decode at 512x512 yields the same result as the non-tiled decode + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # make sure tiled vae decode yields the same result + sd_pipe.enable_vae_tiling() + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + + # Test that tiled decode at 1024x1024 yields a mostly similar result as the non-tiled decode + sd_pipe.disable_vae_tiling() + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe( + [prompt], + width=1024, + height=1024, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + sd_pipe.enable_vae_tiling() + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe( + [prompt], + width=1024, + height=1024, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + # the tiling does cause different tonality to the output + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-2 + def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -965,6 +1027,59 @@ def test_stable_diffusion_vae_slicing(self): # There is a small discrepancy at the image borders vs. a fully batched version. assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3 + def test_stable_diffusion_vae_tiling(self): + torch.cuda.reset_peak_memory_stats() + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a photograph of an astronaut riding a horse" + + # enable vae tiling + pipe.enable_vae_tiling() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], + width=1024, + height=1024, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 4 GB is allocated + print("vae_tiling on", mem_bytes) + assert mem_bytes < 4e9 + + # disable vae tiling + pipe.disable_vae_tiling() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], + width=1024, + height=1024, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ) + image = output.images + + # make sure that more than 4 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + print("vae_tiling off", mem_bytes) + assert mem_bytes > 4e9 + # There is a small discrepancy at the image borders vs. a fully batched version. + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2 + def test_stable_diffusion_text2img_pipeline_fp16(self): torch.cuda.reset_peak_memory_stats() model_id = "CompVis/stable-diffusion-v1-4" From ac8b1c2eebbf245851373bdbaa18d3e893d9be21 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 4 Dec 2022 02:32:16 +0800 Subject: [PATCH 04/19] tiled vae docs, disable tiling for images that would have only one tile --- .../source/api/pipelines/stable_diffusion.mdx | 2 ++ docs/source/optimization/fp16.mdx | 28 +++++++++++++++++++ src/diffusers/models/vae.py | 4 +-- 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 6068b961ae26..63b18c211da5 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -78,6 +78,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca - disable_attention_slicing - enable_vae_slicing - disable_vae_slicing + - enable_vae_tiling + - disable_vae_tiling - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index 49fe3876bd4b..54a0915f665f 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -145,6 +145,34 @@ images = pipe([prompt] * 32).images You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches. +## Tiled VAE decode and encode for large images + +Tiled VAE processing makes it possible to work with large images on limited VRAM. For example, generating 4k images in 8GB of VRAM. Tiled VAE decoder splits the image into overlapping tiles, decodes the tiles, and blends the outputs to make the final image. + +You want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use. + +To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example: + +```Python +import torch +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "a beautiful landscape photograph" +pipe.enable_vae_tiling() +pipe.enable_xformers_memory_efficient_attention() +images = pipe([prompt], width=3840, height=2224).images +``` + +The output image will have some tile-to-tile tone variation from the tiles having separate decoders, but you shouldn't see sharp seams between the tiles. The tiling is turned off for images that are 512x512 or smaller. + + ## Offloading to CPU with accelerate for memory savings For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass. diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index be1581154648..98aaf05c3a52 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -601,7 +601,7 @@ def disable_slicing(self): self.use_slicing = False def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - if self.use_tiling: + if self.use_tiling and (x.shape[-1] > 512 or x.shape[-2] > 512): return self.tiled_encode(x, return_dict=return_dict) h = self.encoder(x) @@ -614,7 +614,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_tiling: + if self.use_tiling and (z.shape[-1] > 64 or z.shape[-2] > 64): return self.tiled_decode(z, return_dict=return_dict) z = self.post_quant_conv(z) From 0a96a8184f64b044201523526ac73f870a193083 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 4 Dec 2022 02:59:39 +0800 Subject: [PATCH 05/19] tiled vae tests, use channels_last memory format --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 48c06ad5819e..5839afdb90f6 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1034,6 +1034,8 @@ def test_stable_diffusion_vae_tiling(self): pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) pipe.enable_attention_slicing() + pipe.unet = pipe.unet.to(memory_format=torch.channels_last) + pipe.vae = pipe.vae.to(memory_format=torch.channels_last) prompt = "a photograph of an astronaut riding a horse" From 2b0454d8ade12f13744f69dc0a3a6ea8576ec8eb Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 4 Dec 2022 03:05:53 +0800 Subject: [PATCH 06/19] tiled vae tests, use smaller test image --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 5839afdb90f6..f1df07c95186 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1045,8 +1045,8 @@ def test_stable_diffusion_vae_tiling(self): with torch.autocast(torch_device): output_chunked = pipe( [prompt], - width=1024, - height=1024, + width=640, + height=640, generator=generator, guidance_scale=7.5, num_inference_steps=2, @@ -1066,8 +1066,8 @@ def test_stable_diffusion_vae_tiling(self): with torch.autocast(torch_device): output = pipe( [prompt], - width=1024, - height=1024, + width=640, + height=640, generator=generator, guidance_scale=7.5, num_inference_steps=2, From 307fd12567e1cbd0cc0ff487c613b1cf9da67980 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 4 Dec 2022 03:24:17 +0800 Subject: [PATCH 07/19] tiled vae tests, remove tiling test from fast tests --- .../stable_diffusion/test_stable_diffusion.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index f1df07c95186..c40d05ce620d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -631,34 +631,6 @@ def test_stable_diffusion_vae_tiling(self): assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 - # Test that tiled decode at 1024x1024 yields a mostly similar result as the non-tiled decode - sd_pipe.disable_vae_tiling() - generator = torch.Generator(device=device).manual_seed(0) - output_1 = sd_pipe( - [prompt], - width=1024, - height=1024, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - ) - - sd_pipe.enable_vae_tiling() - generator = torch.Generator(device=device).manual_seed(0) - output_2 = sd_pipe( - [prompt], - width=1024, - height=1024, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - ) - - # the tiling does cause different tonality to the output - assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-2 - def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -1057,7 +1029,6 @@ def test_stable_diffusion_vae_tiling(self): mem_bytes = torch.cuda.max_memory_allocated() torch.cuda.reset_peak_memory_stats() # make sure that less than 4 GB is allocated - print("vae_tiling on", mem_bytes) assert mem_bytes < 4e9 # disable vae tiling @@ -1077,9 +1048,7 @@ def test_stable_diffusion_vae_tiling(self): # make sure that more than 4 GB is allocated mem_bytes = torch.cuda.max_memory_allocated() - print("vae_tiling off", mem_bytes) assert mem_bytes > 4e9 - # There is a small discrepancy at the image borders vs. a fully batched version. assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2 def test_stable_diffusion_text2img_pipeline_fp16(self): From 541f27517ddab4e81d4813ea2c1b8861c3e631a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:07:11 +0100 Subject: [PATCH 08/19] up --- .../pipelines/stable_diffusion/overview.mdx | 56 +------- src/diffusers/hub_utils.py | 132 ------------------ 2 files changed, 2 insertions(+), 186 deletions(-) delete mode 100644 src/diffusers/hub_utils.py diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index e6127eee45cc..87d3d694a42a 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -79,59 +79,7 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionPipelineOutput [[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput -<<<<<<< HEAD:docs/source/api/pipelines/stable_diffusion.mdx - -## StableDiffusionPipeline -[[autodoc]] StableDiffusionPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_vae_slicing - - disable_vae_slicing + - all + - call - enable_vae_tiling - disable_vae_tiling - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention - -## StableDiffusionImg2ImgPipeline -[[autodoc]] StableDiffusionImg2ImgPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention - -## StableDiffusionInpaintPipeline -[[autodoc]] StableDiffusionInpaintPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention - -## StableDiffusionDepth2ImgPipeline -[[autodoc]] StableDiffusionDepth2ImgPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention - -## StableDiffusionImageVariationPipeline -[[autodoc]] StableDiffusionImageVariationPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention - - -## StableDiffusionUpscalePipeline -[[autodoc]] StableDiffusionUpscalePipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention -======= ->>>>>>> 664b4de9e22a825875b6bad45867f8f77cdf95d6:docs/source/en/api/pipelines/stable_diffusion/overview.mdx diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py deleted file mode 100644 index a3d2731d043a..000000000000 --- a/src/diffusers/hub_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import sys -from pathlib import Path -from typing import Dict, Optional, Union -from uuid import uuid4 - -from huggingface_hub import HfFolder, whoami - -from . import __version__ -from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -from .utils.import_utils import ( - _flax_version, - _jax_version, - _onnxruntime_version, - _torch_version, - is_flax_available, - is_modelcards_available, - is_onnx_available, - is_torch_available, -) - - -if is_modelcards_available(): - from modelcards import CardData, ModelCard - - -logger = logging.get_logger(__name__) - - -MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" -SESSION_ID = uuid4().hex -HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES -DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES -HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/" - - -def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: - """ - Formats a user-agent string with basic info about a request. - """ - ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" - if DISABLE_TELEMETRY or HF_HUB_OFFLINE: - return ua + "; telemetry/off" - if is_torch_available(): - ua += f"; torch/{_torch_version}" - if is_flax_available(): - ua += f"; jax/{_jax_version}" - ua += f"; flax/{_flax_version}" - if is_onnx_available(): - ua += f"; onnxruntime/{_onnxruntime_version}" - # CI will set this value to True - if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: - ua += "; is_ci/true" - if isinstance(user_agent, dict): - ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) - elif isinstance(user_agent, str): - ua += "; " + user_agent - return ua - - -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - -def create_model_card(args, model_name): - if not is_modelcards_available: - raise ValueError( - "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can" - " install the package with `pip install modelcards`." - ) - - if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: - return - - hub_token = args.hub_token if hasattr(args, "hub_token") else None - repo_name = get_full_repo_name(model_name, token=hub_token) - - model_card = ModelCard.from_template( - card_data=CardData( # Card metadata object that will be converted to YAML block - language="en", - license="apache-2.0", - library_name="diffusers", - tags=[], - datasets=args.dataset_name, - metrics=[], - ), - template_path=MODEL_CARD_TEMPLATE_PATH, - model_name=model_name, - repo_name=repo_name, - dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None, - learning_rate=args.learning_rate, - train_batch_size=args.train_batch_size, - eval_batch_size=args.eval_batch_size, - gradient_accumulation_steps=( - args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None - ), - adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, - adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, - adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None, - adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None, - lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None, - lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None, - ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None, - ema_power=args.ema_power if hasattr(args, "ema_power") else None, - ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None, - mixed_precision=args.mixed_precision, - ) - - card_path = os.path.join(args.output_dir, "README.md") - model_card.save(card_path) From 928c6d324f02d24d4009f792a58915684ecc8d4a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:12:28 +0100 Subject: [PATCH 09/19] up --- .../pipelines/stable_diffusion/overview.mdx | 4 - .../pipelines/stable_diffusion/text2img.mdx | 4 +- src/diffusers/models/autoencoder_kl.py | 147 ++++++++++++++++-- 3 files changed, 134 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 87d3d694a42a..160fa0d2ebce 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -79,7 +79,3 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionPipelineOutput [[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput - - all - - call - - enable_vae_tiling - - disable_vae_tiling diff --git a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx index 274caf64a2f7..590617636fa4 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx @@ -36,4 +36,6 @@ Available Checkpoints are: - enable_vae_slicing - disable_vae_slicing - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention \ No newline at end of file + - disable_xformers_memory_efficient_attention + - enable_vae_tiling + - disable_vae_tiling diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 99725498fae6..0dd996ef9efe 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -109,8 +109,41 @@ def __init__( self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) self.use_slicing = False + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling and (x.shape[-1] > 512 or x.shape[-2] > 512): + return self.tiled_encode(x, return_dict=return_dict) + h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) @@ -121,6 +154,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > 64 or z.shape[-2] > 64): + return self.tiled_decode(z, return_dict=return_dict) + z = self.post_quant_conv(z) dec = self.decoder(z) @@ -129,22 +165,6 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) - def enable_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @apply_forward_hook def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: if self.use_slicing and z.shape[0] > 1: @@ -158,6 +178,101 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=decoded) + def blend_v(self, a, b, blend_width): + for y in range(blend_width): + b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width) + return b + + def blend_h(self, a, b, blend_width): + for x in range(blend_width): + b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. + The end result of tiled encoding is different from non-tiled encoding due to each tile using a different + encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may + still see tile-sized changes in the look of the output, but they should be much less noticeable. + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. + """ + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], 384): + row = [] + for j in range(0, x.shape[3], 384): + tile = x[:, :, i : i + 512, j : j + 512] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, 16) + if j > 0: + tile = self.blend_h(row[j - 1], tile, 16) + result_row.append(tile[:, :, :48, :48]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r"""Decode a batch of images using a tiled decoder. + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several + steps. This is useful to keep memory use constant regardless of image size. + The end result of tiled decoding is different from non-tiled decoding due to each tile using a different + decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may + still see tile-sized changes in the look of the output, but they should be much less noticeable. + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], 48): + row = [] + for j in range(0, z.shape[3], 48): + tile = z[:, :, i : i + 64, j : j + 64] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, 128) + if j > 0: + tile = self.blend_h(row[j - 1], tile, 128) + result_row.append(tile[:, :, :384, :384]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + def forward( self, sample: torch.FloatTensor, From 94781b6e7d59ab7022e6ac90c8c982bc8332b583 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:13:23 +0100 Subject: [PATCH 10/19] make style --- src/diffusers/models/autoencoder_kl.py | 31 +++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 0dd996ef9efe..ca2056b5ccba 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -111,9 +111,9 @@ def __init__( def enable_tiling(self, use_tiling: bool = True): r""" - Enable tiled VAE decoding. - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow + the processing of larger images. """ self.use_tiling = use_tiling @@ -126,9 +126,8 @@ def disable_tiling(self): def enable_slicing(self): r""" - Enable sliced VAE decoding. - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True @@ -190,12 +189,12 @@ def blend_h(self, a, b, blend_width): def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. - When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. - The end result of tiled encoding is different from non-tiled encoding due to each tile using a different - encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may - still see tile-sized changes in the look of the output, but they should be much less noticeable. Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. @@ -233,12 +232,12 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r"""Decode a batch of images using a tiled decoder. - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several - steps. This is useful to keep memory use constant regardless of image size. - The end result of tiled decoding is different from non-tiled decoding due to each tile using a different - decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may - still see tile-sized changes in the look of the output, but they should be much less noticeable. Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is + different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. From 6adfedfabf704e58d4255dee7997c53d58d7e401 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:15:49 +0100 Subject: [PATCH 11/19] Apply suggestions from code review --- docs/source/en/optimization/fp16.mdx | 2 -- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index d6aa0373de9c..444d4296c477 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -162,9 +162,7 @@ images = pipe([prompt], width=3840, height=2224).images The output image will have some tile-to-tile tone variation from the tiles having separate decoders, but you shouldn't see sharp seams between the tiles. The tiling is turned off for images that are 512x512 or smaller. -======= ->>>>>>> 664b4de9e22a825875b6bad45867f8f77cdf95d6:docs/source/en/optimization/fp16.mdx ## Offloading to CPU with accelerate for memory savings For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass. diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index e5e52a1d86c6..0a0622753707 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -422,7 +422,7 @@ def test_stable_diffusion_vae_slicing(self): def test_stable_diffusion_vae_tiling(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet - scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4") vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") From 8e7b8c218eadbb778c9a542aea43473b85f31714 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:16:04 +0100 Subject: [PATCH 12/19] Apply suggestions from code review --- docs/source/en/optimization/fp16.mdx | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index 444d4296c477..d1fd6b33b087 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -133,7 +133,6 @@ images = pipe([prompt] * 32).images You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches. -<<<<<<< HEAD:docs/source/optimization/fp16.mdx ## Tiled VAE decode and encode for large images Tiled VAE processing makes it possible to work with large images on limited VRAM. For example, generating 4k images in 8GB of VRAM. Tiled VAE decoder splits the image into overlapping tiles, decodes the tiles, and blends the outputs to make the final image. From a03d41f344f69558c8bacb8f5ef31fa0220095ad Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:16:28 +0100 Subject: [PATCH 13/19] Apply suggestions from code review --- docs/source/en/optimization/fp16.mdx | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index d1fd6b33b087..6ea09305f83a 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -147,7 +147,6 @@ from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - revision="fp16", torch_dtype=torch.float16, ) pipe = pipe.to("cuda") From 68d9b29584a397ee72438790bf1352a87dccd777 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 16:22:19 +0000 Subject: [PATCH 14/19] make style --- src/diffusers/models/autoencoder_kl.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index ca2056b5ccba..1b5ef1fe5488 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -191,12 +191,11 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen r"""Encode a batch of images using a tiled encoder. Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the look of the output, but they should be much less noticeable. - x (`torch.FloatTensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): + x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. """ # Split the image into 512x512 tiles and encode them separately. @@ -234,12 +233,12 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ r"""Decode a batch of images using a tiled decoder. Args: When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the look of the output, but they should be much less noticeable. - z (`torch.FloatTensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): + z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to + `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ # Split z into overlapping 64x64 tiles and decode them separately. From 99a37337670eb65c80f7ad2783191429558429c3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 16:53:55 +0000 Subject: [PATCH 15/19] improve naming --- docs/source/en/optimization/fp16.mdx | 9 +++--- src/diffusers/models/autoencoder_kl.py | 43 +++++++++++++++++--------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index 6ea09305f83a..eef1dcec90f5 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -141,20 +141,21 @@ You want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example: -```Python +```python import torch -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, ) +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to("cuda") - prompt = "a beautiful landscape photograph" pipe.enable_vae_tiling() pipe.enable_xformers_memory_efficient_attention() -images = pipe([prompt], width=3840, height=2224).images + +image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0] ``` The output image will have some tile-to-tile tone variation from the tiles having separate decoders, but you shouldn't see sharp seams between the tiles. The tiling is turned off for images that are 512x512 or smaller. diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 1b5ef1fe5488..79280be1aa40 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -107,7 +107,14 @@ def __init__( self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False + self.use_tiling = False + + # only relevant if tiling is enabled + self.tile_sample_min_size = self.config.sample_size + self.tile_latent_min_size = self.config.sample_size / (2 ** len(self.block_out_channels)) + self.tile_overlap_factor = 0.25 def enable_tiling(self, use_tiling: bool = True): r""" @@ -140,7 +147,7 @@ def disable_slicing(self): @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - if self.use_tiling and (x.shape[-1] > 512 or x.shape[-2] > 512): + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): return self.tiled_encode(x, return_dict=return_dict) h = self.encoder(x) @@ -153,7 +160,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_tiling and (z.shape[-1] > 64 or z.shape[-2] > 64): + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) z = self.post_quant_conv(z) @@ -198,12 +205,16 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_width = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_width + # Split the image into 512x512 tiles and encode them separately. rows = [] - for i in range(0, x.shape[2], 384): + for i in range(0, x.shape[2], overlap_size): row = [] - for j in range(0, x.shape[3], 384): - tile = x[:, :, i : i + 512, j : j + 512] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) @@ -215,10 +226,10 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, 16) + tile = self.blend_v(rows[i - 1][j], tile, blend_width) if j > 0: - tile = self.blend_h(row[j - 1], tile, 16) - result_row.append(tile[:, :, :48, :48]) + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) moments = torch.cat(result_rows, dim=2) @@ -241,13 +252,17 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_width = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_width + # Split z into overlapping 64x64 tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] - for i in range(0, z.shape[2], 48): + for i in range(0, z.shape[2], overlap_size): row = [] - for j in range(0, z.shape[3], 48): - tile = z[:, :, i : i + 64, j : j + 64] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) @@ -259,10 +274,10 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, 128) + tile = self.blend_v(rows[i - 1][j], tile, blend_width) if j > 0: - tile = self.blend_h(row[j - 1], tile, 128) - result_row.append(tile[:, :, :384, :384]) + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) dec = torch.cat(result_rows, dim=2) From 9728662060509731fc244c063be05979b7d85830 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Mar 2023 17:54:03 +0000 Subject: [PATCH 16/19] finish --- src/diffusers/models/autoencoder_kl.py | 13 +++++++++---- .../audio_diffusion/test_audio_diffusion.py | 1 + .../stable_diffusion/test_stable_diffusion.py | 19 ++++--------------- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 79280be1aa40..ad90b6d91115 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -111,9 +111,14 @@ def __init__( self.use_slicing = False self.use_tiling = False - # only relevant if tiling is enabled + # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size - self.tile_latent_min_size = self.config.sample_size / (2 ** len(self.block_out_channels)) + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 def enable_tiling(self, use_tiling: bool = True): @@ -206,7 +211,7 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_width = int(self.tile_sample_min_size * self.tile_overlap_factor) + blend_width = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_width # Split the image into 512x512 tiles and encode them separately. @@ -253,7 +258,7 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_width = int(self.tile_latent_min_size * self.tile_overlap_factor) + blend_width = int(self.tile_sample_min_size * self.tile_overlap_factor) row_limit = self.tile_sample_min_size - blend_width # Split z into overlapping 64x64 tiles and decode them separately. diff --git a/tests/pipelines/audio_diffusion/test_audio_diffusion.py b/tests/pipelines/audio_diffusion/test_audio_diffusion.py index 770780285d3e..ba389d9c936d 100644 --- a/tests/pipelines/audio_diffusion/test_audio_diffusion.py +++ b/tests/pipelines/audio_diffusion/test_audio_diffusion.py @@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self): ) return vqvae, unet + @slow def test_audio_diffusion(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator mel = Mel() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 0a0622753707..6ac43cbeb0be 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -421,22 +421,11 @@ def test_stable_diffusion_vae_slicing(self): def test_stable_diffusion_vae_tiling(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4") - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + components = self.get_dummy_components() # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=self.dummy_extractor, - ) + components["safety_checker"] = None + sd_pipe = StableDiffusionPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -451,7 +440,7 @@ def test_stable_diffusion_vae_tiling(self): generator = torch.Generator(device=device).manual_seed(0) output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1 def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator From 4eb4981caa82b57c4660f18de2168ce9c37f372f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Mar 2023 17:25:49 +0100 Subject: [PATCH 17/19] apply suggestions --- src/diffusers/models/autoencoder_kl.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index ad90b6d91115..9cb0a4b2432b 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -189,14 +189,14 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=decoded) - def blend_v(self, a, b, blend_width): - for y in range(blend_width): - b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width) + def blend_v(self, a, b, blend_extent): + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b - def blend_h(self, a, b, blend_width): - for x in range(blend_width): - b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width) + def blend_h(self, a, b, blend_extent): + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: @@ -211,8 +211,8 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_width = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_width + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent # Split the image into 512x512 tiles and encode them separately. rows = [] @@ -231,9 +231,9 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_width) + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) @@ -258,8 +258,8 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_width = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_width + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent # Split z into overlapping 64x64 tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. @@ -279,9 +279,9 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[ # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_width) + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) From 907788a34b677862d50fbb175d7167d294314b2c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Mar 2023 17:26:23 +0100 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/models/autoencoder_kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9cb0a4b2432b..f28544aa8df3 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -203,7 +203,7 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen r"""Encode a batch of images using a tiled encoder. Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the look of the output, but they should be much less noticeable. From 48793b7b29a75ab6c3234adeccee505ad86bb02e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Mar 2023 17:26:39 +0100 Subject: [PATCH 19/19] up --- src/diffusers/models/autoencoder_kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index f28544aa8df3..9cb0a4b2432b 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -203,7 +203,7 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen r"""Encode a batch of images using a tiled encoder. Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the look of the output, but they should be much less noticeable.