@@ -1,3 +1,4 @@
import gc
import inspect
import warnings
from typing import List, Optional, Union
@@ -90,6 +91,14 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_minimal_memory_usage(self):
"""Moves only unet to fp16 and to CUDA, while keepping lighter models on CPUs"""
self.unet.to(torch.float16).to(torch.device("cuda"))
self.enable_attention_slicing(1)

torch.cuda.empty_cache()
gc.collect()
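
A minimal usage sketch of the new method (the checkpoint id, auth flag, and autocast wrapper mirror the test added at the bottom of this diff; `.images` assumes the pipeline's `return_dict=True` default). Instead of moving the whole pipeline with `pipe.to("cuda")`, only the UNet lands on the GPU:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
)
pipe.enable_minimal_memory_usage()  # UNet in fp16 on CUDA; text encoder, VAE, safety checker stay on CPU

with torch.autocast("cuda"):
    image = pipe("a photograph of an astronaut riding a horse").images[0]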

@torch.no_grad()
def __call__(
self,
@@ -136,7 +145,7 @@ def __call__(
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
@@ -150,16 +159,16 @@ def __call__(
"""

if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
# device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)

# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
# if device is None:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# self.to(device)

if isinstance(prompt, str):
batch_size = 1
@@ -179,7 +188,7 @@ def __call__(
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
text_embeddings = self.text_encoder(text_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)

# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -191,7 +200,9 @@ def __call__(
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(
self.unet.device
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
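
Between this comment and the next hunk, the diff elides the actual concatenation and the per-step guidance math (applied later under `# perform guidance`). A self-contained sketch of that classifier-free guidance round trip, with dummy tensors standing in for the UNet outputs:

import torch

guidance_scale = 7.5
uncond, cond = torch.randn(2, 4), torch.randn(2, 4)  # stand-ins for the two noise predictions
noise_pred = torch.cat([uncond, cond])               # both passes run as one batched forward
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)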
@@ -224,7 +235,7 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

@@ -246,7 +257,9 @@ def __call__(
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(
latent_model_input.to(self.unet.device), t.to(self.unet.device), encoder_hidden_states=text_embeddings
).sample

# perform guidance
if do_classifier_free_guidance:
@@ -255,19 +268,27 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(
noise_pred, i, latents.to(self.unet.device), **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(
noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
).prev_sample

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents.to(self.vae.device)).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = image.to(self.vae.device).cpu().permute(0, 2, 3, 1).numpy()

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
safety_checker_input = (
self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
.to(self.vae.device)
.to(self.vae.dtype)
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

if output_type == "pil":
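The recurring pattern in the diff above is to leave every submodule on its own device and move tensors at each call boundary, relying on the `.device` property that diffusers and transformers models expose; dtypes are left to `torch.autocast`, except for the safety-checker input, which is cast to `self.vae.dtype` explicitly. A generic sketch of that convention (`on_device_of` is a hypothetical helper, not part of the PR):

import torch

def on_device_of(module: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: move a tensor to whatever device `module` lives on
    return tensor.to(next(module.parameters()).device)

# usage, mirroring the diff: embeddings from a CPU text encoder, fed to a CUDA UNet
# text_embeddings = on_device_of(unet, text_embeddings)
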
21 changes: 21 additions & 0 deletions tests/test_pipelines.py
@@ -1168,6 +1168,27 @@ def test_stable_diffusion_memory_chunking(self):
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_further_memory_chunking(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="main")
pipe.set_progress_bar_config(disable=None)
pipe.enable_minimal_memory_usage()

prompt = "a photograph of an astronaut riding a horse"

# make attention efficient
pipe.enable_attention_slicing()
with torch.autocast(torch_device):
_ = pipe([prompt], guidance_scale=7.5, num_inference_steps=10, output_type="numpy")

mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 2.3 GB is allocated
assert mem_bytes < 2.3 * 10**9
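
The measurement idiom the new test relies on, as a standalone sketch (`run_workload` is a hypothetical stand-in for the pipeline call under test): reset the peak counter, run once, then read the high-water mark.

import torch

torch.cuda.reset_peak_memory_stats()
run_workload()  # hypothetical: e.g. one pipe(...) call
peak = torch.cuda.max_memory_allocated()
print(f"peak CUDA allocation: {peak / 10**9:.2f} GB")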

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_pipeline(self):