Commit 24895a1

Fix cpu offloading (#1177)
* Fix cpu offloading
* get offloaded devices locally for SD pipelines
1 parent 598ff76 commit 24895a1

7 files changed: +107 -61 lines changed

src/diffusers/pipeline_utils.py

Lines changed: 0 additions & 2 deletions
@@ -230,8 +230,6 @@ def device(self) -> torch.device:
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.device == torch.device("meta"):
-                    return torch.device("cpu")
                 return module.device
         return torch.device("cpu")
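With the meta-device special case removed, `DiffusionPipeline.device` now simply reports wherever the first registered module lives, which is the meta device once sequential CPU offload has been enabled; the `_execution_device` property added in the pipeline files below is what the pipelines consult instead. A minimal usage sketch, not part of the commit (the checkpoint, the fp16 revision, and the presence of a CUDA device plus accelerate are assumptions):

import torch
from diffusers import StableDiffusionPipeline

# assumed checkpoint and fp16 weights; requires accelerate and a CUDA device
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # offloads unet, text_encoder, vae, safety_checker

print(pipe.device)             # meta -- the weights are offloaded, so this is no longer a usable target
print(pipe._execution_device)  # e.g. cuda:0, read back from Accelerate's module hooks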

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 28 additions & 12 deletions
@@ -195,6 +195,24 @@ def enable_sequential_cpu_offload(self):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,
@@ -286,6 +304,8 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        device = self._execution_device
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -302,7 +322,7 @@ def __call__(
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
@@ -342,7 +362,7 @@ def __call__(
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -362,20 +382,18 @@ def __call__(
         latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == "mps":
+            if device.type == "mps":
                 # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(device)
 
         # set timesteps and move to the correct device
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps_tensor = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -424,9 +442,7 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
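Inside `__call__` the change is uniform: every tensor that used to be moved with `.to(self.device)` is now moved with `.to(device)`, where `device = self._execution_device`, so inputs land on the GPU that Accelerate's hooks will run each offloaded submodule on rather than on the meta device that `self.device` reports. The property boils down to the following lookup, shown here as a standalone sketch (the `_hf_hook` attribute and its `execution_device` field are attached by Accelerate's CPU-offload hooks; the helper name is illustrative):

import torch

def execution_device_of(model: torch.nn.Module, fallback: torch.device) -> torch.device:
    # scan the offloaded model for an Accelerate hook that knows the execution device
    for module in model.modules():
        hook = getattr(module, "_hf_hook", None)
        if hook is not None and getattr(hook, "execution_device", None) is not None:
            return torch.device(hook.execution_device)
    return fallback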

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 28 additions & 9 deletions
@@ -183,6 +183,25 @@ def enable_sequential_cpu_offload(self):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -292,6 +311,8 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        device = self._execution_device
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
@@ -314,7 +335,7 @@ def __call__(
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
@@ -348,7 +369,7 @@ def __call__(
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
             # duplicate unconditional embeddings for each generation per prompt
             seq_len = uncond_embeddings.shape[1]
@@ -362,7 +383,7 @@ def __call__(
 
         # encode the init image into latents and scale the latents
         latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_image = init_image.to(device=device, dtype=latents_dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
@@ -393,10 +414,10 @@ def __call__(
         init_timestep = min(init_timestep, num_inference_steps)
 
         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=device)
 
         # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=latents_dtype)
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -419,7 +440,7 @@ def __call__(
 
         # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+        timesteps = self.scheduler.timesteps[t_start:].to(device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
@@ -448,9 +469,7 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
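The same `_execution_device` property is duplicated here, with the `# Copied from ...` marker keeping it in sync with the text-to-image pipeline, so img2img can also run fully offloaded. An illustrative usage sketch, not from the commit (the checkpoint, fp16 weights, local image path, and strength value are assumptions; the image argument name follows this version's `init_image` parameter):

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # requires accelerate and a CUDA device

init_image = Image.open("sketch-mountains-input.jpg").convert("RGB").resize((768, 512))
# __call__ resolves the execution device internally, so no manual pipe.to("cuda") is needed
image = pipe(prompt="A fantasy landscape", init_image=init_image, strength=0.75).images[0]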

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 32 additions & 15 deletions
@@ -183,6 +183,25 @@ def enable_sequential_cpu_offload(self):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -303,6 +322,8 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        device = self._execution_device
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -319,7 +340,7 @@ def __call__(
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
@@ -359,7 +380,7 @@ def __call__(
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -379,17 +400,15 @@ def __call__(
         latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == "mps":
+            if device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(device)
 
         # prepare mask and masked_image
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
@@ -398,9 +417,9 @@ def __call__(
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
         mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
-        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+        mask = mask.to(device=device, dtype=text_embeddings.dtype)
 
-        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image = masked_image.to(device=device, dtype=text_embeddings.dtype)
 
         # encode the mask image into latents space so we can concatenate it to the latents
         masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -416,7 +435,7 @@ def __call__(
         )
 
         # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype)
 
         num_channels_mask = mask.shape[1]
         num_channels_masked_image = masked_image_latents.shape[1]
@@ -431,7 +450,7 @@ def __call__(
             )
 
         # set timesteps and move to the correct device
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps_tensor = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -484,9 +503,7 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
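The inpaint pipeline keeps its existing ordering for the mask: it is resized while still in full precision and only then moved to the execution device and cast, which the in-code comment attributes to avoiding breakage when cpu_offload is combined with half precision. A toy sketch of that ordering (the tensor sizes and the assumption that fp16 interpolation on CPU is the failure being avoided are illustrative, not from the commit):

import torch

mask = torch.rand(1, 1, 512, 512)                                  # float32 mask, still on CPU
mask = torch.nn.functional.interpolate(mask, size=(64, 64))        # resize before any dtype change
mask = mask.to(device=torch.device("cuda"), dtype=torch.float16)   # then align with the UNet inputs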

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 6 additions & 4 deletions
@@ -839,20 +839,22 @@ def test_stable_diffusion_low_cpu_mem_usage(self):
 
         assert 2 * low_cpu_mem_usage_time < normal_load_time
 
-    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
 
         pipeline_id = "CompVis/stable-diffusion-v1-4"
         prompt = "Andromeda galaxy in a bottle"
 
         pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
+        pipeline = pipeline.to(torch_device)
         pipeline.enable_attention_slicing(1)
         pipeline.enable_sequential_cpu_offload()
 
-        _ = pipeline(prompt, num_inference_steps=5)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipeline(prompt, generator=generator, num_inference_steps=5)
 
         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 1.5 GB is allocated
-        assert mem_bytes < 1.5 * 10**9
+        # make sure that less than 2.8 GB is allocated
+        assert mem_bytes < 2.8 * 10**9
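The offloading test now moves the fp16 pipeline to the GPU before enabling offload, seeds a generator, resets the CUDA peak-memory statistics with reset_peak_memory_stats in addition to the existing reset_max_memory_allocated call, and relaxes the peak-memory bound to 2.8 GB. The measurement pattern itself, as a minimal sketch with an illustrative threshold:

import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... run the offloaded pipeline here ...

peak_gb = torch.cuda.max_memory_allocated() / 10**9
assert peak_gb < 2.8, f"peak GPU memory was {peak_gb:.2f} GB"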

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 4 additions & 11 deletions
@@ -603,25 +603,18 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
 
         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/img2img/sketch-mountains-input.jpg"
         )
-        expected_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/img2img/fantasy_landscape_k_lms.png"
-        )
         init_image = init_image.resize((768, 512))
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
 
         model_id = "CompVis/stable-diffusion-v1-4"
         lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-            model_id,
-            scheduler=lms,
-            safety_checker=None,
-            device_map="auto",
+            model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -642,5 +635,5 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         )
 
         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 1.5 GB is allocated
-        assert mem_bytes < 1.5 * 10**9
+        # make sure that less than 2.2 GB is allocated
+        assert mem_bytes < 2.2 * 10**9
