From 1fb21d318640aba50786554a55fd12ffab3a8a23 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Mon, 31 Oct 2022 14:36:10 -0300
Subject: [PATCH 1/7] add enable sequential cpu offloading to other stable diffusion pipelines

---
 .../pipeline_stable_diffusion_img2img.py      | 18 ++++++++
 .../pipeline_stable_diffusion_inpaint.py      | 18 ++++++++
 .../stable_diffusion/test_stable_diffusion.py |  2 +-
 .../test_stable_diffusion_img2img.py          | 44 +++++++++++++++++++
 .../test_stable_diffusion_inpaint.py          | 44 +++++++++++++++++++
 5 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 00c364f8e5e3..cd55051f5b04 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -5,6 +5,7 @@
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 39c8d16823af..9618593db59b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -5,6 +5,7 @@
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -150,6 +151,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + @torch.no_grad() def __call__( self, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 4e627846f2ec..f6f632adf33d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -840,7 +840,7 @@ def test_stable_diffusion_accelerate_auto_device(self): assert 2 * meta_device_load_time < normal_load_time @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") - def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 78d001e3c758..e5f33cace72d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -603,3 +603,47 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No ) assert test_callback_fn.has_been_called assert number_of_steps == 38 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/fantasy_landscape_k_lms.png" + ) + init_image = init_image.resize((768, 512)) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + model_id = "CompVis/stable-diffusion-v1-4" + lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, + scheduler=lms, + safety_checker=None, + device_map="auto", + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 1.5 GB is allocated + assert mem_bytes < 1.5 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 0a373ada68bc..2366d481e8fe 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -389,3 +389,47 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 + + @unittest.skipIf(torch_device == "cpu", "This test 
is supposed to run on GPU") + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + model_id = "runwayml/stable-diffusion-inpainting" + pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, safety_checker=None, scheduler=pndm, device_map="auto" + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 1.5 GB is allocated + assert mem_bytes < 1.5 * 10**9 From 028f824cac4d8dab796840cb104639b70e77c391 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Mon, 31 Oct 2022 17:02:09 -0300 Subject: [PATCH 2/7] trigger ci From e3e6f50a4b47328d7b23686198b2c592d35130e9 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 2 Nov 2022 13:18:17 -0300 Subject: [PATCH 3/7] fix styling --- .../pipelines/stable_diffusion/test_stable_diffusion_inpaint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 6efb1adc7523..740603209fbf 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -387,7 +387,6 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 - @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() From d6fcb917d17f40a72f0832dd96305fe2d78c3bae Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 2 Nov 2022 13:27:20 -0300 Subject: [PATCH 4/7] interpolate before converting to device to avoid breking when cpu_offload is enabled with fp16 Co-authored-by: Pedro Gengo --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 773ef7d95edf..c5d4b9edeb61 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -379,11 +379,15 @@ def __call__( # prepare mask and masked_image mask, masked_image = prepare_mask_and_masked_image(image, 
mask_image) - mask = mask.to(device=self.device, dtype=text_embeddings.dtype) - masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) + mask = mask.to(device=self.device, dtype=text_embeddings.dtype) + + masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) + # encode the mask image into latents space so we can concatenate it to the latents masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) From 36603350ed460f358e208742cf08e5590e3b6f23 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 2 Nov 2022 13:31:48 -0300 Subject: [PATCH 5/7] style again I need to stop forgething this thing --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index c5d4b9edeb61..fa461645ca38 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -388,7 +388,6 @@ def __call__( masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) - # encode the mask image into latents space so we can concatenate it to the latents masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) masked_image_latents = 0.18215 * masked_image_latents From a6a4069af59c159b41e5d2736b3a47cbd88c080b Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 2 Nov 2022 16:35:46 -0300 Subject: [PATCH 6/7] fix inpainting bug that could cause device misalignment Co-authored-by: Pedro Gengo --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fa461645ca38..b8e7f949e7ec 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -401,6 +401,9 @@ def __call__( torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype) + num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] From 0eb3d258dca11ec14316fff6e9abfb0d97ae418c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 4 Nov 2022 18:16:57 +0100 Subject: [PATCH 7/7] Apply suggestions from code review --- .../pipelines/stable_diffusion/test_stable_diffusion_img2img.py | 1 + .../pipelines/stable_diffusion/test_stable_diffusion_inpaint.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index ea6a348f7f5e..1926c3a7a6aa 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ 
b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -642,6 +642,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): guidance_scale=7.5, generator=generator, output_type="np", + num_inference_steps=5, ) mem_bytes = torch.cuda.max_memory_allocated() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index b4441e570abb..e8dcb43163da 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -417,6 +417,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): image=init_image, mask_image=mask_image, generator=generator, + num_inference_steps=5, output_type="np", )
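Usage note (illustrative, not part of the patch series): the sketch below shows how the enable_sequential_cpu_offload() method introduced by these patches would typically be called on the img2img pipeline. It assumes accelerate is installed and a CUDA device is available, as the method itself does; the checkpoint id, scheduler, and prompt mirror the new tests above, while the input image path is a placeholder.

    import torch
    from PIL import Image
    from diffusers import LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline

    model_id = "CompVis/stable-diffusion-v1-4"
    lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms)
    pipe.to("cuda")

    # Both memory savers can be combined, as the new tests do: attention slicing
    # bounds the attention peak, and sequential CPU offload keeps unet, text_encoder,
    # vae and the safety checker on CPU until their forward passes run.
    pipe.enable_attention_slicing(1)
    pipe.enable_sequential_cpu_offload()

    # Placeholder input image; any RGB image resized to a multiple of 64 works.
    init_image = Image.open("sketch-mountains-input.jpg").convert("RGB").resize((768, 512))

    image = pipe(
        prompt="A fantasy landscape, trending on artstation",
        init_image=init_image,
        strength=0.75,
        guidance_scale=7.5,
    ).images[0]
    image.save("fantasy_landscape.png")

The trade-off is speed for memory: every submodule is copied to the GPU each time its forward method is called, so generation is slower than keeping the whole pipeline resident, but peak GPU memory stays low (the new tests assert under 1.5 GB allocated).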