Skip to content

Commit 2e9af1e

Browse files
authored
[Core] better support offloading when side loading is enabled. (huggingface#4855)
* better support offloading when side loading is enabled. * load_textual_inversion * better messaging for textual inversion. * fixes * address PR feedback. * sdxl support. * improve messaging * recursive removal when cpu sequential offloading is enabled. * add: lora tests * recurse. * add: offload tests for textual inversion.
1 parent 32cde23 commit 2e9af1e

File tree

6 files changed

+173
-0
lines changed

6 files changed

+173
-0
lines changed

loaders.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
if is_accelerate_available():
4747
from accelerate import init_empty_weights
48+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
4849
from accelerate.utils import set_module_tensor_to_device
4950

5051
logger = logging.get_logger(__name__)
@@ -768,6 +769,21 @@ def load_textual_inversion(
768769
f" `{self.load_textual_inversion.__name__}`"
769770
)
770771

772+
# Remove any existing hooks.
773+
is_model_cpu_offload = False
774+
is_sequential_cpu_offload = False
775+
recursive = False
776+
for _, component in self.components.items():
777+
if isinstance(component, nn.Module):
778+
if hasattr(component, "_hf_hook"):
779+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
780+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
781+
logger.info(
782+
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
783+
)
784+
recursive = is_sequential_cpu_offload
785+
remove_hook_from_module(component, recurse=recursive)
786+
771787
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
772788
force_download = kwargs.pop("force_download", False)
773789
resume_download = kwargs.pop("resume_download", False)
@@ -921,6 +937,12 @@ def load_textual_inversion(
921937
for token_id, embedding in token_ids_and_embeddings:
922938
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
923939

940+
# offload back
941+
if is_model_cpu_offload:
942+
self.enable_model_cpu_offload()
943+
elif is_sequential_cpu_offload:
944+
self.enable_sequential_cpu_offload()
945+
924946

925947
class LoraLoaderMixin:
926948
r"""
@@ -952,6 +974,21 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
952974
kwargs (`dict`, *optional*):
953975
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
954976
"""
977+
# Remove any existing hooks.
978+
is_model_cpu_offload = False
979+
is_sequential_cpu_offload = False
980+
recursive = False
981+
for _, component in self.components.items():
982+
if isinstance(component, nn.Module):
983+
if hasattr(component, "_hf_hook"):
984+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
985+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
986+
logger.info(
987+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
988+
)
989+
recursive = is_sequential_cpu_offload
990+
remove_hook_from_module(component, recurse=recursive)
991+
955992
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
956993
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
957994
self.load_lora_into_text_encoder(
@@ -961,6 +998,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
961998
lora_scale=self.lora_scale,
962999
)
9631000

1001+
# Offload back.
1002+
if is_model_cpu_offload:
1003+
self.enable_model_cpu_offload()
1004+
elif is_sequential_cpu_offload:
1005+
self.enable_sequential_cpu_offload()
1006+
9641007
@classmethod
9651008
def lora_state_dict(
9661009
cls,

pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
15491549
# We could have accessed the unet config from `lora_state_dict()` too. We pass
15501550
# it here explicitly to be able to tell that it's coming from an SDXL
15511551
# pipeline.
1552+
1553+
# Remove any existing hooks.
1554+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1555+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1556+
else:
1557+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1558+
1559+
is_model_cpu_offload = False
1560+
is_sequential_cpu_offload = False
1561+
recursive = False
1562+
for _, component in self.components.items():
1563+
if isinstance(component, torch.nn.Module):
1564+
if hasattr(component, "_hf_hook"):
1565+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1566+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1567+
logger.info(
1568+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1569+
)
1570+
recursive = is_sequential_cpu_offload
1571+
remove_hook_from_module(component, recurse=recursive)
15521572
state_dict, network_alphas = self.lora_state_dict(
15531573
pretrained_model_name_or_path_or_dict,
15541574
unet_config=self.unet.config,
@@ -1576,6 +1596,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
15761596
lora_scale=self.lora_scale,
15771597
)
15781598

1599+
# Offload back.
1600+
if is_model_cpu_offload:
1601+
self.enable_model_cpu_offload()
1602+
elif is_sequential_cpu_offload:
1603+
self.enable_sequential_cpu_offload()
1604+
15791605
@classmethod
15801606
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
15811607
def save_lora_weights(

pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
12121212
# We could have accessed the unet config from `lora_state_dict()` too. We pass
12131213
# it here explicitly to be able to tell that it's coming from an SDXL
12141214
# pipeline.
1215+
1216+
# Remove any existing hooks.
1217+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1218+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1219+
else:
1220+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1221+
1222+
is_model_cpu_offload = False
1223+
is_sequential_cpu_offload = False
1224+
recursive = False
1225+
for _, component in self.components.items():
1226+
if isinstance(component, torch.nn.Module):
1227+
if hasattr(component, "_hf_hook"):
1228+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1229+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1230+
logger.info(
1231+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1232+
)
1233+
recursive = is_sequential_cpu_offload
1234+
remove_hook_from_module(component, recurse=recursive)
12151235
state_dict, network_alphas = self.lora_state_dict(
12161236
pretrained_model_name_or_path_or_dict,
12171237
unet_config=self.unet.config,
@@ -1239,6 +1259,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
12391259
lora_scale=self.lora_scale,
12401260
)
12411261

1262+
# Offload back.
1263+
if is_model_cpu_offload:
1264+
self.enable_model_cpu_offload()
1265+
elif is_sequential_cpu_offload:
1266+
self.enable_sequential_cpu_offload()
1267+
12421268
@classmethod
12431269
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
12441270
def save_lora_weights(

pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
916916
# We could have accessed the unet config from `lora_state_dict()` too. We pass
917917
# it here explicitly to be able to tell that it's coming from an SDXL
918918
# pipeline.
919+
920+
# Remove any existing hooks.
921+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
922+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
923+
else:
924+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
925+
926+
is_model_cpu_offload = False
927+
is_sequential_cpu_offload = False
928+
recursive = False
929+
for _, component in self.components.items():
930+
if isinstance(component, torch.nn.Module):
931+
if hasattr(component, "_hf_hook"):
932+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
933+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
934+
logger.info(
935+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
936+
)
937+
recursive = is_sequential_cpu_offload
938+
remove_hook_from_module(component, recurse=recursive)
919939
state_dict, network_alphas = self.lora_state_dict(
920940
pretrained_model_name_or_path_or_dict,
921941
unet_config=self.unet.config,
@@ -943,6 +963,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
943963
lora_scale=self.lora_scale,
944964
)
945965

966+
# Offload back.
967+
if is_model_cpu_offload:
968+
self.enable_model_cpu_offload()
969+
elif is_sequential_cpu_offload:
970+
self.enable_sequential_cpu_offload()
971+
946972
@classmethod
947973
def save_lora_weights(
948974
self,

pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
10701070
# We could have accessed the unet config from `lora_state_dict()` too. We pass
10711071
# it here explicitly to be able to tell that it's coming from an SDXL
10721072
# pipeline.
1073+
1074+
# Remove any existing hooks.
1075+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1076+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1077+
else:
1078+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1079+
1080+
is_model_cpu_offload = False
1081+
is_sequential_cpu_offload = False
1082+
recursive = False
1083+
for _, component in self.components.items():
1084+
if isinstance(component, torch.nn.Module):
1085+
if hasattr(component, "_hf_hook"):
1086+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1087+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1088+
logger.info(
1089+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1090+
)
1091+
recursive = is_sequential_cpu_offload
1092+
remove_hook_from_module(component, recurse=recursive)
10731093
state_dict, network_alphas = self.lora_state_dict(
10741094
pretrained_model_name_or_path_or_dict,
10751095
unet_config=self.unet.config,
@@ -1097,6 +1117,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
10971117
lora_scale=self.lora_scale,
10981118
)
10991119

1120+
# Offload back.
1121+
if is_model_cpu_offload:
1122+
self.enable_model_cpu_offload()
1123+
elif is_sequential_cpu_offload:
1124+
self.enable_sequential_cpu_offload()
1125+
11001126
@classmethod
11011127
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
11021128
def save_lora_weights(

pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
13841384
# We could have accessed the unet config from `lora_state_dict()` too. We pass
13851385
# it here explicitly to be able to tell that it's coming from an SDXL
13861386
# pipeline.
1387+
1388+
# Remove any existing hooks.
1389+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1390+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
1391+
else:
1392+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
1393+
1394+
is_model_cpu_offload = False
1395+
is_sequential_cpu_offload = False
1396+
recursive = False
1397+
for _, component in self.components.items():
1398+
if isinstance(component, torch.nn.Module):
1399+
if hasattr(component, "_hf_hook"):
1400+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
1401+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
1402+
logger.info(
1403+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
1404+
)
1405+
recursive = is_sequential_cpu_offload
1406+
remove_hook_from_module(component, recurse=recursive)
13871407
state_dict, network_alphas = self.lora_state_dict(
13881408
pretrained_model_name_or_path_or_dict,
13891409
unet_config=self.unet.config,
@@ -1411,6 +1431,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
14111431
lora_scale=self.lora_scale,
14121432
)
14131433

1434+
# Offload back.
1435+
if is_model_cpu_offload:
1436+
self.enable_model_cpu_offload()
1437+
elif is_sequential_cpu_offload:
1438+
self.enable_sequential_cpu_offload()
1439+
14141440
@classmethod
14151441
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
14161442
def save_lora_weights(

0 commit comments

Comments
 (0)