18 changes: 18 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -5,6 +5,7 @@
import torch

import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
`torch.device('meta')`, loaded onto the GPU only when their specific submodule's `forward` method is called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

device = torch.device("cuda")

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
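For context, a minimal usage sketch of the new method (not part of the diff), mirroring the integration tests further down; the fp16 dtype is an assumption chosen to make the memory savings visible, and the model id is the one the tests use:

import torch
from diffusers import StableDiffusionImg2ImgPipeline

# Load and move the pipeline the same way the tests below do.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe.to("cuda")

# From here on, unet/text_encoder/vae/safety_checker are held on the meta
# device and streamed to the GPU submodule by submodule at forward time,
# trading inference speed for a much lower peak memory footprint.
pipe.enable_sequential_cpu_offload()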
28 changes: 26 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -5,6 +5,7 @@
import torch

import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
`torch.device('meta')`, loaded onto the GPU only when their specific submodule's `forward` method is called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

device = torch.device("cuda")

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -361,11 +379,14 @@ def __call__(

# prepare mask and masked_image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)

masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

# encode the mask image into latents space so we can concatenate it to the latents
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -380,6 +401,9 @@ def __call__(
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)

# aligning device and dtype to prevent errors when concatenating it with the latent model input
masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)

num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]

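The reordering above is the subtle part of this file: the mask is now interpolated before the dtype/device conversion. A standalone sketch of the failure this avoids (assuming a PyTorch build where half-precision interpolation is not implemented on CPU, which is the situation the new comment refers to):

import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 512, 512)  # float32 mask, still on CPU

# New order: downsample to latent resolution first, then cast.
ok = F.interpolate(mask, size=(64, 64)).to(dtype=torch.float16)

# The old order under cpu_offload + half precision would interpolate an
# fp16 tensor on CPU, which raises a "not implemented for 'Half'"
# RuntimeError on many PyTorch builds.
try:
    bad = F.interpolate(mask.to(dtype=torch.float16), size=(64, 64))
except RuntimeError as err:
    print(err)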
2 changes: 1 addition & 1 deletion tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -840,7 +840,7 @@ def test_stable_diffusion_low_cpu_mem_usage(self):
assert 2 * low_cpu_mem_usage_time < normal_load_time

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

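All three offloading tests share the same peak-memory measurement pattern. In outline (run_pipeline is an illustrative stand-in for the actual pipe(...) call in each test):

import torch

def run_pipeline():
    # Stand-in for the real pipeline call; any CUDA work between the
    # reset and the readout is measured the same way.
    x = torch.randn(1024, 1024, device="cuda")
    (x @ x).sum().item()

torch.cuda.empty_cache()                 # drop cached blocks from earlier tests
torch.cuda.reset_max_memory_allocated()  # restart the peak-allocation counter

run_pipeline()

# Peak bytes allocated since the reset; the tests assert this stays
# under 1.5 GB even for a full Stable Diffusion run with offloading.
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 1.5 * 10**9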
45 changes: 45 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -599,3 +599,48 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
)
assert test_callback_fn.has_been_called
assert number_of_steps == 38

def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/fantasy_landscape_k_lms.png"
)
init_image = init_image.resize((768, 512))
expected_image = np.array(expected_image, dtype=np.float32) / 255.0

model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()

prompt = "A fantasy landscape, trending on artstation"

generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
num_inference_steps=5,
)

mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 1.5 GB is allocated
assert mem_bytes < 1.5 * 10**9
47 changes: 46 additions & 1 deletion tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -378,4 +378,49 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
image = output.images[0]

assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-3
assert np.abs(expected_image - image).max() < 1e-2

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0

model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
num_inference_steps=5,
output_type="np",
)

mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 1.5 GB is allocated
assert mem_bytes < 1.5 * 10**9