Skip to content

Commit 94abbc0

Browse files
committed
fix adapter_name issues
1 parent 86c7d69 commit 94abbc0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/diffusers/loaders.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,7 @@ def load_lora_weights(
11871187
network_alphas=network_alphas,
11881188
unet=self.unet,
11891189
low_cpu_mem_usage=low_cpu_mem_usage,
1190+
adapter_name=adapter_name,
11901191
_pipeline=self,
11911192
)
11921193
self.load_lora_into_text_encoder(
@@ -1497,7 +1498,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
14971498

14981499
@classmethod
14991500
def load_lora_into_unet(
1500-
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
1501+
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
15011502
):
15021503
"""
15031504
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -2987,7 +2988,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
29872988
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
29882989

29892990
# Override to properly handle the loading and unloading of the additional text encoder.
2990-
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
2991+
def load_lora_weights(
2992+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
2993+
):
29912994
"""
29922995
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
29932996
`self.text_encoder`.
@@ -3005,6 +3008,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30053008
Parameters:
30063009
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
30073010
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
3011+
adapter_name (`str`, *optional*):
3012+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3012+
`default_{i}`, where `i` is the total number of adapters being loaded.
30083014
kwargs (`dict`, *optional*):
30093015
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
30103016
"""
@@ -3031,6 +3037,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30313037
text_encoder=self.text_encoder,
30323038
prefix="text_encoder",
30333039
lora_scale=self.lora_scale,
3040+
adapter_name=adapter_name,
30343041
_pipeline=self,
30353042
)
30363043

@@ -3042,6 +3049,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30423049
text_encoder=self.text_encoder_2,
30433050
prefix="text_encoder_2",
30443051
lora_scale=self.lora_scale,
3052+
adapter_name=adapter_name,
30453053
_pipeline=self,
30463054
)
30473055

0 commit comments

Comments (0)