
Commit 1172c96

Authored by piEsposito, pedrogengo, and patrickvonplaten
add enable sequential cpu offloading to other stable diffusion pipelines (#1085)
* add enable sequential cpu offloading to other stable diffusion pipelines
* trigger ci
* fix styling
* interpolate before converting to device to avoid breaking when cpu_offload is enabled with fp16
* style again, I need to stop forgetting this thing
* fix inpainting bug that could cause device misalignment
* Apply suggestions from code review

Co-authored-by: Pedro Gengo <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2fcae69 commit 1172c96
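The change adds an `enable_sequential_cpu_offload()` method to the img2img and inpaint pipelines, mirroring the one already available on the text-to-image pipeline. As a rough usage sketch (illustrative only, not part of this commit; it assumes a CUDA machine with accelerate installed and uses the same model id as the tests below):

    # Hedged sketch: enabling sequential CPU offload on an img2img pipeline.
    import torch
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    )
    pipe.enable_sequential_cpu_offload()  # weights stay on CPU; each submodule moves to CUDA only for its forward pass

    # generation then runs as usual, e.g.:
    # image = pipe(prompt="A fantasy landscape", init_image=init_image, strength=0.75).images[0]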

File tree: 5 files changed, +136 -4 lines changed


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 18 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch

 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)

+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
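The docstring above only summarizes accelerate's behaviour. As a loose mental model (a simplified, hypothetical sketch, not accelerate's `cpu_offload` implementation), sequential offloading amounts to keeping each submodule's weights off the GPU and using forward hooks to move them to the execution device only for the duration of that submodule's call:

    # Simplified, hypothetical illustration of per-module offloading via forward hooks.
    # accelerate's cpu_offload uses a meta-device + hook mechanism; this only conveys the idea.
    import torch

    def naive_sequential_offload(module: torch.nn.Module, execution_device: torch.device):
        module.to("cpu")  # weights live in CPU RAM between calls

        def pre_hook(mod, inputs):
            mod.to(execution_device)  # load weights onto the GPU right before forward
            return tuple(x.to(execution_device) if torch.is_tensor(x) else x for x in inputs)

        def post_hook(mod, inputs, output):
            mod.to("cpu")  # release GPU memory again once forward is done
            return output

        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(post_hook)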

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 26 additions & 2 deletions
@@ -5,6 +5,7 @@
 import torch

 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -361,11 +379,14 @@ def __call__(

         # prepare mask and masked_image
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
-        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
-        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

         # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
         mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+
+        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

         # encode the mask image into latents space so we can concatenate it to the latents
         masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -380,6 +401,9 @@ def __call__(
             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )

+        # aligning device to prevent device errors when concating it with the latent model input
+        masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
+
         num_channels_mask = mask.shape[1]
         num_channels_masked_image = masked_image_latents.shape[1]

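The reordering in the hunks above matters because, with sequential CPU offload enabled, `self.device` can still point at the CPU while `text_embeddings.dtype` is fp16, and interpolating a half-precision tensor on CPU is not supported on the PyTorch versions targeted here; resizing the float32 mask first and casting afterwards avoids that, and the later `.to(...)` keeps `masked_image_latents` on the same device and dtype as the latents it is concatenated with. A small hypothetical repro of the ordering issue (not from the commit):

    # Illustrative only: why the mask is resized before it is cast to the pipeline dtype.
    import torch
    import torch.nn.functional as F

    mask = torch.rand(1, 1, 512, 512)            # float32 mask, still on CPU
    mask = F.interpolate(mask, size=(64, 64))    # resize first, in float32 -- this always works
    mask = mask.to(dtype=torch.float16)          # cast afterwards

    # casting first and then interpolating on CPU is what used to break with cpu_offload + fp16:
    # F.interpolate(mask.to(torch.float16), size=(64, 64))  # may raise "not implemented for 'Half'" on CPU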

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -840,7 +840,7 @@ def test_stable_diffusion_low_cpu_mem_usage(self):
         assert 2 * low_cpu_mem_usage_time < normal_load_time

     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
-    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()


tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 45 additions & 0 deletions
@@ -599,3 +599,48 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
         )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 38
+
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape_k_lms.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            model_id,
+            scheduler=lms,
+            safety_checker=None,
+            device_map="auto",
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=5,
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 46 additions & 1 deletion
@@ -378,4 +378,49 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
         image = output.images[0]

         assert image.shape == (512, 512, 3)
-        assert np.abs(expected_image - image).max() < 1e-3
+        assert np.abs(expected_image - image).max() < 1e-2
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "runwayml/stable-diffusion-inpainting"
+        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            model_id, safety_checker=None, scheduler=pndm, device_map="auto"
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            generator=generator,
+            num_inference_steps=5,
+            output_type="np",
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
