From 16367e10df296c4e6d2a201cc228c5e22dbcba68 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 16 Sep 2022 14:44:50 -0300
Subject: [PATCH 1/5] enable shrinking of sd to run on 2GB GPUs

---
 .../pipeline_stable_diffusion.py             | 45 ++++++++++++-------
 1 file changed, 30 insertions(+), 15 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ce17f2e0ee41..8a6efd0ef997 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,5 +1,6 @@
 import inspect
 import warnings
+import gc
 from typing import List, Optional, Union
 
 import torch
@@ -89,7 +90,17 @@ def disable_attention_slicing(self):
         """
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
-
+
+    def enable_minimal_memory_usage(self):
+        """Moves only unet to fp16 and to CUDA, while keeping lighter models on CPUs
+        """
+        self.unet.to(torch.float16).to(torch.device("cuda"))
+        self.enable_attention_slicing(1)
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+
     @torch.no_grad()
     def __call__(
         self,
@@ -136,7 +147,7 @@ def __call__(
                 tensor will be generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
@@ -157,9 +168,9 @@ def __call__(
             )
 
         # Set device as before (to be removed in 0.3.0)
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.to(device)
+        # if device is None:
+        #     device = "cuda" if torch.cuda.is_available() else "cpu"
+        # self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -179,7 +190,7 @@ def __call__(
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -191,7 +202,7 @@ def __call__(
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -224,7 +235,7 @@ def __call__(
 
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
@@ -246,7 +257,11 @@ def __call__(
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(
+                latent_model_input.to(self.unet.device),
+                t.to(self.unet.device),
+                encoder_hidden_states=text_embeddings
+            ).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -255,19 +270,19 @@ def __call__(
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, i.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(self.vae.device)).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = image.to(self.safety_checker.dtype).cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.safety_checker.dtype).to(self.device)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
@@ -276,4 +291,4 @@ def __call__(
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
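The core move in PATCH 1 is a split placement: only the UNet, which runs once per denoising step, is cast to fp16 and moved to CUDA, while the text encoder and VAE stay on the CPU and every tensor is shuttled to a module's device right before use. Below is a minimal sketch of that choreography outside the pipeline class; the helper name and the simplified `scheduler.step` signature are illustrative, not diffusers API.

    import torch

    def denoise_with_split_placement(text_encoder, unet, vae, input_ids, latents, timesteps, scheduler):
        # encode on CPU: one text-encoder pass per prompt is cheap there
        text_embeddings = text_encoder(input_ids.to(text_encoder.device))[0]
        # only the embeddings cross to the GPU, where the fp16 unet lives
        text_embeddings = text_embeddings.to(unet.device)
        for t in timesteps:
            # the denoising loop is the hot path, so it runs entirely on the GPU
            noise_pred = unet(latents.to(unet.device), t, encoder_hidden_states=text_embeddings).sample
            latents = scheduler.step(noise_pred, t, latents.to(unet.device)).prev_sample
        # decode wherever the vae was left (the CPU, in this setup)
        return vae.decode(latents.to(vae.device)).sample

The trade-off is explicit: each `.to(...)` is a host-device copy, so the drop in resident GPU memory is paid for with transfer latency on every step.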
From 94561c4d9c1c7429d27b8682a4411c3058ce9aff Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 16 Sep 2022 15:02:29 -0300
Subject: [PATCH 2/5] add test to ensure reduced GPU memory usage

---
 .../pipeline_stable_diffusion.py |  4 +--
 tests/test_pipelines.py          | 25 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 8a6efd0ef997..32412faf3a00 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -282,7 +282,7 @@ def __call__(
         image = image.to(self.safety_checker.dtype).cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.safety_checker.dtype).to(self.device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.safety_checker.device).to(self.safety_checker.dtype)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
@@ -291,4 +291,4 @@ def __call__(
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 102a55a93e4b..832e36fadf58 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1168,6 +1168,31 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
 
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_further_memory_chunking(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, use_auth_token=True, revision="main"
+        )
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_minimal_memory_usage()
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # make attention efficient
+        pipe.enable_attention_slicing()
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 2.3 GB is allocated
+        assert mem_bytes < 2.3 * 10**9
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_text2img_pipeline(self):
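The new test relies on CUDA's peak-allocation counter. A standalone sketch of the measurement idiom, assuming `run_workload` is any callable that launches CUDA work:

    import torch

    def peak_cuda_bytes(run_workload) -> int:
        torch.cuda.empty_cache()              # release cached blocks from earlier runs
        torch.cuda.reset_peak_memory_stats()  # zero the high-water mark
        run_workload()
        torch.cuda.synchronize()              # wait for queued kernels to finish
        return torch.cuda.max_memory_allocated()

    # usage, mirroring the test's assertion:
    # assert peak_cuda_bytes(lambda: pipe([prompt], num_inference_steps=10)) < 2.3 * 10**9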
From c8b4581369c7d3e53c87e36dfe74e90383a18469 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 16 Sep 2022 15:26:04 -0300
Subject: [PATCH 3/5] format code using black

---
 .../pipeline_stable_diffusion.py | 32 +++++++++++--------
 tests/test_pipelines.py          |  8 ++---
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 32412faf3a00..ea3c6d4dad1b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -90,17 +90,15 @@ def disable_attention_slicing(self):
         """
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
-
+
     def enable_minimal_memory_usage(self):
-        """Moves only unet to fp16 and to CUDA, while keeping lighter models on CPUs
-        """
+        """Moves only unet to fp16 and to CUDA, while keeping lighter models on CPUs"""
         self.unet.to(torch.float16).to(torch.device("cuda"))
         self.enable_attention_slicing(1)
-
+
         torch.cuda.empty_cache()
         gc.collect()
-
-
+
     @torch.no_grad()
     def __call__(
         self,
@@ -200,7 +200,9 @@ def __call__(
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(
+                self.unet.device
+            )
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -258,9 +258,7 @@ def __call__(
 
             # predict the noise residual
             noise_pred = self.unet(
-                latent_model_input.to(self.unet.device),
-                t.to(self.unet.device),
-                encoder_hidden_states=text_embeddings
+                latent_model_input.to(self.unet.device), t.to(self.unet.device), encoder_hidden_states=text_embeddings
             ).sample
 
             # perform guidance
@@ -270,9 +268,13 @@ def __call__(
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(
+                    noise_pred, i.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
+                ).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(
+                    noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
+                ).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
@@ -282,7 +284,11 @@ def __call__(
         image = image.to(self.safety_checker.dtype).cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.safety_checker.device).to(self.safety_checker.dtype)
+        safety_cheker_input = (
+            self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
+            .to(self.safety_checker.device)
+            .to(self.safety_checker.dtype)
+        )
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 832e36fadf58..d33c3972086a 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1173,9 +1173,7 @@ def test_stable_diffusion_memory_chunking(self):
     def test_stable_diffusion_further_memory_chunking(self):
         torch.cuda.reset_peak_memory_stats()
         model_id = "CompVis/stable-diffusion-v1-4"
-        pipe = StableDiffusionPipeline.from_pretrained(
-            model_id, use_auth_token=True, revision="main"
-        )
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="main")
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_minimal_memory_usage()
 
@@ -1184,9 +1182,7 @@ def test_stable_diffusion_further_memory_chunking(self):
         # make attention efficient
         pipe.enable_attention_slicing()
         with torch.autocast(torch_device):
-            output_chunked = pipe(
-                [prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
-            )
+            output_chunked = pipe([prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
 
         mem_bytes = torch.cuda.max_memory_allocated()
         torch.cuda.reset_peak_memory_stats()
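PATCH 3 is mechanical: black rewrites the hand-wrapped calls from the first two patches, collapsing short argument lists onto one line and turning the long method chain into a parenthesized chain with one `.to(...)` per line. The same result can be reproduced through black's Python API; the 119-character line length is an assumption about the repository's configuration, and the project's usual entry point is its `make style` target rather than this snippet.

    import black

    long_line = (
        'safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), '
        'return_tensors="pt").to(self.safety_checker.device).to(self.safety_checker.dtype)\n'
    )
    # format_str applies the same rules the black CLI would
    print(black.format_str(long_line, mode=black.FileMode(line_length=119)))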
From 0a11fe4bfcb3c2c05651792ef307056b87ca4e0f Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 16 Sep 2022 15:32:45 -0300
Subject: [PATCH 4/5] fix imports and remove unused variables

---
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++--
 tests/test_pipelines.py                                     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ea3c6d4dad1b..d9f5ce2c7fef 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,6 +1,6 @@
+import gc
 import inspect
 import warnings
-import gc
 from typing import List, Optional, Union
 
 import torch
@@ -159,7 +159,7 @@ def __call__(
         """
         if "torch_device" in kwargs:
-            device = kwargs.pop("torch_device")
+            # device = kwargs.pop("torch_device")
             warnings.warn(
                 "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                 " Consider using `pipe.to(torch_device)` instead."
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index d33c3972086a..2ad871705854 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1182,7 +1182,7 @@ def test_stable_diffusion_further_memory_chunking(self):
         # make attention efficient
         pipe.enable_attention_slicing()
         with torch.autocast(torch_device):
-            output_chunked = pipe([prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
+            _ = pipe([prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
 
         mem_bytes = torch.cuda.max_memory_allocated()
         torch.cuda.reset_peak_memory_stats()
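One hunk in PATCH 4 touches the deprecation shim for the legacy `torch_device` keyword. A generic sketch of that pattern follows, with illustrative names; the reason to `pop` the keyword (rather than only test for membership) is that a stale entry would otherwise leak into any later call that forwards `**kwargs`.

    import warnings

    class PipelineShim:
        """Illustrative stand-in showing only the deprecation idiom."""

        def __call__(self, *args, **kwargs):
            if "torch_device" in kwargs:
                # pop even if the value is discarded, so the dead keyword
                # does not propagate through **kwargs forwarding
                kwargs.pop("torch_device")
                warnings.warn(
                    "`torch_device` is deprecated as an input argument to `__call__` and"
                    " will be removed in v0.3.0. Consider using `pipe.to(torch_device)` instead."
                )
            return args, kwargs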
From 6a42078f5bef27e548751b522ebfb67737b55d79 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 16 Sep 2022 16:16:30 -0300
Subject: [PATCH 5/5] fix tensor devices for cases when safety checker is
 mocked

---
 .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index d9f5ce2c7fef..8e61876139c0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -269,7 +269,7 @@ def __call__(
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(
-                    noise_pred, i.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
+                    noise_pred, i, latents.to(self.unet.device), **extra_step_kwargs
                 ).prev_sample
             else:
                 latents = self.scheduler.step(
@@ -281,13 +281,13 @@ def __call__(
         image = self.vae.decode(latents.to(self.vae.device)).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.to(self.safety_checker.dtype).cpu().permute(0, 2, 3, 1).numpy()
+        image = image.to(self.vae.device).to(self.vae.dtype).cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
         safety_cheker_input = (
             self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
-            .to(self.safety_checker.device)
-            .to(self.safety_checker.dtype)
+            .to(self.vae.device)
+            .to(self.vae.dtype)
         )
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
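PATCH 5 re-anchors post-processing on the VAE's device and dtype because tests sometimes replace the safety checker with a lightweight stub that exposes neither attribute. A sketch of the failure mode and of the safer reference point; the mock class is hypothetical, not taken from the test suite.

    import torch

    class MockSafetyChecker:
        # a test double like this has no `.device` or `.dtype`, so reading
        # `self.safety_checker.device` would raise AttributeError
        def __call__(self, images, clip_input):
            return images, [False] * len(images)

    def place_like(tensor: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
        # derive placement from a component guaranteed to be a real module
        # (the vae), not from one that may have been mocked out
        param = next(module.parameters())
        return tensor.to(param.device).to(param.dtype)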