diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ce17f2e0ee41..8e61876139c0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,3 +1,4 @@
+import gc
 import inspect
 import warnings
 from typing import List, Optional, Union
@@ -90,6 +91,14 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_minimal_memory_usage(self):
+        """Moves only the unet to fp16 and to CUDA, while keeping lighter models on the CPU."""
+        self.unet.to(torch.float16).to(torch.device("cuda"))
+        self.enable_attention_slicing(1)
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
     @torch.no_grad()
     def __call__(
         self,
@@ -136,7 +145,7 @@
                 tensor will ge generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
@@ -150,16 +159,16 @@
         """
 
         if "torch_device" in kwargs:
-            device = kwargs.pop("torch_device")
+            # device = kwargs.pop("torch_device")
             warnings.warn(
                 "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                 " Consider using `pipe.to(torch_device)` instead."
             )
 
             # Set device as before (to be removed in 0.3.0)
-            if device is None:
-                device = "cuda" if torch.cuda.is_available() else "cpu"
-            self.to(device)
+            # if device is None:
+            #     device = "cuda" if torch.cuda.is_available() else "cpu"
+            # self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -179,7 +188,7 @@
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -191,7 +200,9 @@
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(
+                self.unet.device
+            )
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -224,7 +235,7 @@
 
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
 
@@ -246,7 +257,9 @@
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(
+                latent_model_input.to(self.unet.device), t.to(self.unet.device), encoder_hidden_states=text_embeddings
+            ).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -255,19 +268,27 @@
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(
+                    noise_pred, i, latents.to(self.unet.device), **extra_step_kwargs
+                ).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(
+                    noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
+                ).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(self.vae.device)).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = image.to(self.vae.device).cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        safety_cheker_input = (
+            self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
+            .to(self.vae.device)
+            .to(self.vae.dtype)
+        )
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 102a55a93e4b..2ad871705854 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1168,6 +1168,27 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
 
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_further_memory_chunking(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="main")
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_minimal_memory_usage()
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # make attention efficient
+        pipe.enable_attention_slicing()
+        with torch.autocast(torch_device):
+            _ = pipe([prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 2.3 GB is allocated
+        assert mem_bytes < 2.3 * 10**9
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_text2img_pipeline(self):
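For reference, a minimal usage sketch of the `enable_minimal_memory_usage` entry point added above, following the call pattern of the new test. The checkpoint id, prompt, step count, and output handling are illustrative, not part of the diff:

    import torch
    from diffusers import StableDiffusionPipeline

    # Load on CPU first; enable_minimal_memory_usage then casts only the unet to fp16
    # and moves it to CUDA, keeping the text encoder and VAE on the CPU.
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=True
    )
    pipe.enable_minimal_memory_usage()

    prompt = "a photograph of an astronaut riding a horse"
    with torch.autocast("cuda"):
        image = pipe([prompt], guidance_scale=7.5, num_inference_steps=50).images[0]
    image.save("astronaut_rides_horse.png")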