@@ -1187,6 +1187,7 @@ def load_lora_weights(
11871187 network_alphas = network_alphas ,
11881188 unet = self .unet ,
11891189 low_cpu_mem_usage = low_cpu_mem_usage ,
1190+ adapter_name = adapter_name ,
11901191 _pipeline = self ,
11911192 )
11921193 self .load_lora_into_text_encoder (
@@ -1497,7 +1498,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
14971498
14981499 @classmethod
14991500 def load_lora_into_unet (
1500- cls , state_dict , network_alphas , unet , low_cpu_mem_usage = None , _pipeline = None , adapter_name = None
1501+ cls , state_dict , network_alphas , unet , low_cpu_mem_usage = None , adapter_name = None , _pipeline = None
15011502 ):
15021503 """
15031504 This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -2987,7 +2988,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
29872988 """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
29882989
29892990 # Overrride to properly handle the loading and unloading of the additional text encoder.
2990- def load_lora_weights (self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], ** kwargs ):
2991+ def load_lora_weights (
2992+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
2993+ ):
29912994 """
29922995 Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
29932996 `self.text_encoder`.
@@ -3005,6 +3008,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30053008 Parameters:
30063009 pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
30073010 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
3011+ adapter_name (`str`, *optional*):
3012+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3013+ `default_{i}` where i is the total number of adapters being loaded.
30083014 kwargs (`dict`, *optional*):
30093015 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
30103016 """
@@ -3031,6 +3037,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30313037 text_encoder = self .text_encoder ,
30323038 prefix = "text_encoder" ,
30333039 lora_scale = self .lora_scale ,
3040+ adapter_name = adapter_name ,
30343041 _pipeline = self ,
30353042 )
30363043
@@ -3042,6 +3049,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30423049 text_encoder = self .text_encoder_2 ,
30433050 prefix = "text_encoder_2" ,
30443051 lora_scale = self .lora_scale ,
3052+ adapter_name = adapter_name ,
30453053 _pipeline = self ,
30463054 )
30473055
0 commit comments