From 2104e78dd02c0985bf51634578cc485f40ab09bc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 17 Nov 2022 16:01:41 +0000 Subject: [PATCH 1/3] inpaint fix k lms --- .../pipeline_stable_diffusion_inpaint.py | 7 ++-- .../test_stable_diffusion_inpaint.py | 40 +++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fea2b3e5a8ee..992b4ca272c2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -169,7 +169,7 @@ def disable_attention_slicing(self): self.enable_attention_slicing(None) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -180,7 +180,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: @@ -586,9 +586,8 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index ce231a1a460b..c6a976f4f2f0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -22,6 +22,7 @@ from diffusers import ( AutoencoderKL, + LMSDiscreteScheduler, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel, @@ -421,6 +422,45 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 + def test_stable_diffusion_inpaint_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench_k_lms.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + + # switch to LMS + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() From 1fff480838c52a78b35f13dafeca277aafc0b5b6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 17 Nov 2022 16:03:00 +0000 Subject: [PATCH 2/3] onnox as well --- .../stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 6228824b3d24..9fbdc8f6d461 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -408,8 +408,8 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latnets in the channel dimension - latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) latent_model_input = latent_model_input.cpu().numpy() # predict the noise residual From 611fc7fc8a47ed75600cf37660a0c869cd5e10be Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 17 Nov 2022 16:03:37 +0000 Subject: [PATCH 3/3] up --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 4 ++-- .../pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py | 2 +- .../pipelines/stable_diffusion/pipeline_cycle_diffusion.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint_legacy.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 01b2051db410..afb2c5288640 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -179,7 +179,7 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -190,7 +190,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 294a43e86e61..4f5c1c354626 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -189,7 +189,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index b5f409929234..f3fc3f8eb514 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -220,7 +220,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 4bfbc5fbcb6e..d9ec997b91f3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -187,7 +187,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 5c2a3e952375..2cc7c5c167e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -200,7 +200,7 @@ def enable_sequential_cpu_offload(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: