Commit 965dfe1

move to cpu_offload along with minor internal changes to make it work
1 parent: 4b4c69a

File tree

4 files changed, +9 -12 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
 _deps = [
     "Pillow<10.0",  # keep the PIL.Image.Resampling deprecation away
-    "accelerate>=0.11.0",
+    "accelerate>=0.14.0",
     "black==22.8",
     "datasets",
     "filelock",

src/diffusers/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
@@ -206,6 +206,8 @@ def device(self) -> torch.device:
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                if module.device == torch.device("meta"):
+                    return torch.device("cpu")
                 return module.device
         return torch.device("cpu")
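
The meta-device check above is needed because accelerate's cpu_offload replaces a wrapped module's weights with placeholders on PyTorch's "meta" device and only materializes the real weights around each forward pass, so the module itself no longer reports a usable device. A minimal sketch of that behavior (assuming accelerate>=0.14.0 and an available CUDA device; the toy Linear module is illustrative, not part of the commit):

    import torch
    from accelerate import cpu_offload

    # Toy stand-in for a pipeline component (illustrative only).
    module = torch.nn.Linear(4, 4)

    # cpu_offload keeps the real weights on CPU and swaps them onto the
    # execution device only for the duration of each forward pass.
    cpu_offload(module, execution_device=torch.device("cuda"))

    # Outside a forward pass the parameters sit on the meta device, which
    # is the case the device property above now maps back to "cpu".
    print(next(module.parameters()).device)  # meta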

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 9 deletions
@@ -121,20 +121,15 @@ def disable_attention_slicing(self):
 
     def cuda_with_minimal_gpu_usage(self):
         if is_accelerate_available():
-            from accelerate.hooks import attach_execution_device_hook
+            from accelerate import cpu_offload
         else:
             raise ImportError("Please install accelerate via `pip install accelerate`")
 
         device = torch.device("cuda")
-
-        self.unet.half().to(device)
-        attach_execution_device_hook(self.unet, device)
-        self.unet.forward = torch.autocast("cuda")(self.unet.forward)
         self.enable_attention_slicing(1)
 
-        for cpu_offloaded_model in [self.text_encoder, self.vae, self.safety_checker]:
-            cpu_offloaded_model.to(torch.float32)
-            attach_execution_device_hook(cpu_offloaded_model, "cpu")
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            cpu_offload(cpu_offloaded_model, device)
 
     @torch.no_grad()
     def __call__(
@@ -310,7 +305,7 @@ def __call__(
                     self.device
                 )
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.unet.device, dtype=latents_dtype)
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
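
Taken together, the new code path is meant to be used roughly as follows. A hedged sketch (the checkpoint id and prompt are placeholders, not part of the commit):

    import torch
    from diffusers import StableDiffusionPipeline

    # Placeholder checkpoint id; substitute any Stable Diffusion weights.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # Offloads unet, text_encoder, vae and safety_checker through
    # accelerate's cpu_offload and enables attention slicing (slice size 1).
    pipe.cuda_with_minimal_gpu_usage()

    image = pipe("a photo of an astronaut riding a horse").images[0]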

tests/test_pipelines.py

Lines changed: 2 additions & 2 deletions
@@ -2293,5 +2293,5 @@ def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
         _ = pipeline(prompt)
 
         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 2.2 GB is allocated
-        assert mem_bytes < 2.2 * 10**9
+        # make sure that less than 0.8 GB is allocated
+        assert mem_bytes < 0.8 * 10**9
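
The tightened bound follows the usual pattern for asserting on peak CUDA memory. A sketch of that pattern (the counter reset is an assumption about the elided test setup, not shown in this hunk):

    import torch

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # ... run the offloaded pipeline here ...

    mem_bytes = torch.cuda.max_memory_allocated()
    assert mem_bytes < 0.8 * 10**9  # i.e. under 0.8 GB at peak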
