Skip to content

Commit 4191dde

Browse files
Revert revert and install accelerate main (#4963)
* Revert "Temp Revert "[Core] better support offloading when side loading is enabled… (#4927)"". This reverts commit 2ab1704.
* tests: install accelerate from main
1 parent 2ab1704 commit 4191dde

File tree

11 files changed

+281
-2
lines changed

11 files changed

+281
-2
lines changed

.github/workflows/pr_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ jobs:
6767
run: |
6868
apt-get update && apt-get install libsndfile1-dev libgl1 -y
6969
python -m pip install -e .[quality,test]
70+
python -m pip install git+https://github.com/huggingface/accelerate.git
7071
7172
- name: Environment
7273
run: |

.github/workflows/push_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ jobs:
6363
run: |
6464
apt-get update && apt-get install libsndfile1-dev libgl1 -y
6565
python -m pip install -e .[quality,test]
66+
python -m pip install git+https://github.com/huggingface/accelerate.git
6667
6768
- name: Environment
6869
run: |

.github/workflows/push_tests_mps.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
${CONDA_RUN} python -m pip install --upgrade pip
4141
${CONDA_RUN} python -m pip install -e .[quality,test]
4242
${CONDA_RUN} python -m pip install torch torchvision torchaudio
43-
${CONDA_RUN} python -m pip install accelerate --upgrade
43+
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate.git
4444
${CONDA_RUN} python -m pip install transformers --upgrade
4545
4646
- name: Environment

src/diffusers/loaders.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
if is_accelerate_available():
4747
from accelerate import init_empty_weights
48+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
4849
from accelerate.utils import set_module_tensor_to_device
4950

5051
logger = logging.get_logger(__name__)
@@ -778,6 +779,21 @@ def load_textual_inversion(
778779
f" `{self.load_textual_inversion.__name__}`"
779780
)
780781

782+
# Remove any existing hooks.
783+
is_model_cpu_offload = False
784+
is_sequential_cpu_offload = False
785+
recursive = False
786+
for _, component in self.components.items():
787+
if isinstance(component, nn.Module):
788+
if hasattr(component, "_hf_hook"):
789+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
790+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
791+
logger.info(
792+
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
793+
)
794+
recursive = is_sequential_cpu_offload
795+
remove_hook_from_module(component, recurse=recursive)
796+
781797
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
782798
force_download = kwargs.pop("force_download", False)
783799
resume_download = kwargs.pop("resume_download", False)
@@ -931,6 +947,12 @@ def load_textual_inversion(
931947
for token_id, embedding in token_ids_and_embeddings:
932948
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
933949

950+
# offload back
951+
if is_model_cpu_offload:
952+
self.enable_model_cpu_offload()
953+
elif is_sequential_cpu_offload:
954+
self.enable_sequential_cpu_offload()
955+
934956

935957
class LoraLoaderMixin:
936958
r"""
@@ -962,6 +984,21 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
962984
kwargs (`dict`, *optional*):
963985
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
964986
"""
987+
# Remove any existing hooks.
988+
is_model_cpu_offload = False
989+
is_sequential_cpu_offload = False
990+
recurive = False
991+
for _, component in self.components.items():
992+
if isinstance(component, nn.Module):
993+
if hasattr(component, "_hf_hook"):
994+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
995+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
996+
logger.info(
997+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
998+
)
999+
recurive = is_sequential_cpu_offload
1000+
remove_hook_from_module(component, recurse=recurive)
1001+
9651002
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
9661003
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
9671004
self.load_lora_into_text_encoder(
@@ -971,6 +1008,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
9711008
lora_scale=self.lora_scale,
9721009
)
9731010

1011+
# Offload back.
1012+
if is_model_cpu_offload:
1013+
self.enable_model_cpu_offload()
1014+
elif is_sequential_cpu_offload:
1015+
self.enable_sequential_cpu_offload()
1016+
9741017
@classmethod
9751018
def lora_state_dict(
9761019
cls,

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
15491549
# We could have accessed the unet config from `lora_state_dict()` too. We pass
15501550
# it here explicitly to be able to tell that it's coming from an SDXL
15511551
# pipeline.
1552+
1553+
# Remove any existing hooks.
1554+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1555+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1556+
else:
1557+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1558+
1559+
is_model_cpu_offload = False
1560+
is_sequential_cpu_offload = False
1561+
recursive = False
1562+
for _, component in self.components.items():
1563+
if isinstance(component, torch.nn.Module):
1564+
if hasattr(component, "_hf_hook"):
1565+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1566+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1567+
logger.info(
1568+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1569+
)
1570+
recursive = is_sequential_cpu_offload
1571+
remove_hook_from_module(component, recurse=recursive)
15521572
state_dict, network_alphas = self.lora_state_dict(
15531573
pretrained_model_name_or_path_or_dict,
15541574
unet_config=self.unet.config,
@@ -1576,6 +1596,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
15761596
lora_scale=self.lora_scale,
15771597
)
15781598

1599+
# Offload back.
1600+
if is_model_cpu_offload:
1601+
self.enable_model_cpu_offload()
1602+
elif is_sequential_cpu_offload:
1603+
self.enable_sequential_cpu_offload()
1604+
15791605
@classmethod
15801606
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
15811607
def save_lora_weights(

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
12161216
# We could have accessed the unet config from `lora_state_dict()` too. We pass
12171217
# it here explicitly to be able to tell that it's coming from an SDXL
12181218
# pipeline.
1219+
1220+
# Remove any existing hooks.
1221+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1222+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1223+
else:
1224+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1225+
1226+
is_model_cpu_offload = False
1227+
is_sequential_cpu_offload = False
1228+
recursive = False
1229+
for _, component in self.components.items():
1230+
if isinstance(component, torch.nn.Module):
1231+
if hasattr(component, "_hf_hook"):
1232+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1233+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1234+
logger.info(
1235+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1236+
)
1237+
recursive = is_sequential_cpu_offload
1238+
remove_hook_from_module(component, recurse=recursive)
12191239
state_dict, network_alphas = self.lora_state_dict(
12201240
pretrained_model_name_or_path_or_dict,
12211241
unet_config=self.unet.config,
@@ -1243,6 +1263,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
12431263
lora_scale=self.lora_scale,
12441264
)
12451265

1266+
# Offload back.
1267+
if is_model_cpu_offload:
1268+
self.enable_model_cpu_offload()
1269+
elif is_sequential_cpu_offload:
1270+
self.enable_sequential_cpu_offload()
1271+
12461272
@classmethod
12471273
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
12481274
def save_lora_weights(

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
922922
# We could have accessed the unet config from `lora_state_dict()` too. We pass
923923
# it here explicitly to be able to tell that it's coming from an SDXL
924924
# pipeline.
925+
926+
# Remove any existing hooks.
927+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
928+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
929+
else:
930+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
931+
932+
is_model_cpu_offload = False
933+
is_sequential_cpu_offload = False
934+
recursive = False
935+
for _, component in self.components.items():
936+
if isinstance(component, torch.nn.Module):
937+
if hasattr(component, "_hf_hook"):
938+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
939+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
940+
logger.info(
941+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
942+
)
943+
recursive = is_sequential_cpu_offload
944+
remove_hook_from_module(component, recurse=recursive)
925945
state_dict, network_alphas = self.lora_state_dict(
926946
pretrained_model_name_or_path_or_dict,
927947
unet_config=self.unet.config,
@@ -949,6 +969,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
949969
lora_scale=self.lora_scale,
950970
)
951971

972+
# Offload back.
973+
if is_model_cpu_offload:
974+
self.enable_model_cpu_offload()
975+
elif is_sequential_cpu_offload:
976+
self.enable_sequential_cpu_offload()
977+
952978
@classmethod
953979
def save_lora_weights(
954980
self,

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
10721072
# We could have accessed the unet config from `lora_state_dict()` too. We pass
10731073
# it here explicitly to be able to tell that it's coming from an SDXL
10741074
# pipeline.
1075+
1076+
# Remove any existing hooks.
1077+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1078+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1079+
else:
1080+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1081+
1082+
is_model_cpu_offload = False
1083+
is_sequential_cpu_offload = False
1084+
recursive = False
1085+
for _, component in self.components.items():
1086+
if isinstance(component, torch.nn.Module):
1087+
if hasattr(component, "_hf_hook"):
1088+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1089+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1090+
logger.info(
1091+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1092+
)
1093+
recursive = is_sequential_cpu_offload
1094+
remove_hook_from_module(component, recurse=recursive)
10751095
state_dict, network_alphas = self.lora_state_dict(
10761096
pretrained_model_name_or_path_or_dict,
10771097
unet_config=self.unet.config,
@@ -1099,6 +1119,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
10991119
lora_scale=self.lora_scale,
11001120
)
11011121

1122+
# Offload back.
1123+
if is_model_cpu_offload:
1124+
self.enable_model_cpu_offload()
1125+
elif is_sequential_cpu_offload:
1126+
self.enable_sequential_cpu_offload()
1127+
11021128
@classmethod
11031129
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
11041130
def save_lora_weights(

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
13921392
# We could have accessed the unet config from `lora_state_dict()` too. We pass
13931393
# it here explicitly to be able to tell that it's coming from an SDXL
13941394
# pipeline.
1395+
1396+
# Remove any existing hooks.
1397+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1398+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1399+
else:
1400+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1401+
1402+
is_model_cpu_offload = False
1403+
is_sequential_cpu_offload = False
1404+
recursive = False
1405+
for _, component in self.components.items():
1406+
if isinstance(component, torch.nn.Module):
1407+
if hasattr(component, "_hf_hook"):
1408+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1409+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1410+
logger.info(
1411+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1412+
)
1413+
recursive = is_sequential_cpu_offload
1414+
remove_hook_from_module(component, recurse=recursive)
13951415
state_dict, network_alphas = self.lora_state_dict(
13961416
pretrained_model_name_or_path_or_dict,
13971417
unet_config=self.unet.config,
@@ -1419,6 +1439,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
14191439
lora_scale=self.lora_scale,
14201440
)
14211441

1442+
# Offload back.
1443+
if is_model_cpu_offload:
1444+
self.enable_model_cpu_offload()
1445+
elif is_sequential_cpu_offload:
1446+
self.enable_sequential_cpu_offload()
1447+
14221448
@classmethod
14231449
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
14241450
def save_lora_weights(

tests/models/test_lora_layers.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,42 @@ def test_a1111(self):
10811081

10821082
self.assertTrue(np.allclose(images, expected, atol=1e-3))
10831083

1084+
def test_a1111_with_model_cpu_offload(self):
1085+
generator = torch.Generator().manual_seed(0)
1086+
1087+
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
1088+
pipe.enable_model_cpu_offload()
1089+
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
1090+
lora_filename = "light_and_shadow.safetensors"
1091+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
1092+
1093+
images = pipe(
1094+
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
1095+
).images
1096+
1097+
images = images[0, -3:, -3:, -1].flatten()
1098+
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
1099+
1100+
self.assertTrue(np.allclose(images, expected, atol=1e-3))
1101+
1102+
def test_a1111_with_sequential_cpu_offload(self):
1103+
generator = torch.Generator().manual_seed(0)
1104+
1105+
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
1106+
pipe.enable_sequential_cpu_offload()
1107+
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
1108+
lora_filename = "light_and_shadow.safetensors"
1109+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
1110+
1111+
images = pipe(
1112+
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
1113+
).images
1114+
1115+
images = images[0, -3:, -3:, -1].flatten()
1116+
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
1117+
1118+
self.assertTrue(np.allclose(images, expected, atol=1e-3))
1119+
10841120
def test_kohya_sd_v15_with_higher_dimensions(self):
10851121
generator = torch.Generator().manual_seed(0)
10861122

@@ -1257,10 +1293,10 @@ def test_sdxl_1_0_lora(self):
12571293
generator = torch.Generator().manual_seed(0)
12581294

12591295
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
1296+
pipe.enable_model_cpu_offload()
12601297
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
12611298
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
12621299
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
1263-
pipe.enable_model_cpu_offload()
12641300

12651301
images = pipe(
12661302
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -1411,3 +1447,21 @@ def test_sdxl_1_0_fuse_unfuse_all(self):
14111447
assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
14121448
assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
14131449
assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
1450+
1451+
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
1452+
generator = torch.Generator().manual_seed(0)
1453+
1454+
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
1455+
pipe.enable_sequential_cpu_offload()
1456+
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
1457+
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
1458+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
1459+
1460+
images = pipe(
1461+
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
1462+
).images
1463+
1464+
images = images[0, -3:, -3:, -1].flatten()
1465+
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
1466+
1467+
self.assertTrue(np.allclose(images, expected, atol=1e-3))

0 commit comments

Comments
 (0)