@@ -1205,6 +1205,7 @@ def load_lora_weights(
12051205 network_alphas = network_alphas ,
12061206 unet = self .unet ,
12071207 low_cpu_mem_usage = low_cpu_mem_usage ,
1208+ adapter_name = adapter_name ,
12081209 _pipeline = self ,
12091210 )
12101211 self .load_lora_into_text_encoder (
@@ -1515,7 +1516,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
15151516
15161517 @classmethod
15171518 def load_lora_into_unet (
1518- cls , state_dict , network_alphas , unet , low_cpu_mem_usage = None , _pipeline = None , adapter_name = None
1519+ cls , state_dict , network_alphas , unet , low_cpu_mem_usage = None , adapter_name = None , _pipeline = None
15191520 ):
15201521 """
15211522 This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -3005,7 +3006,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
30053006 """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
30063007
30073008 # Overrride to properly handle the loading and unloading of the additional text encoder.
3008- def load_lora_weights (self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], ** kwargs ):
3009+ def load_lora_weights (
3010+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3011+ ):
30093012 """
30103013 Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
30113014 `self.text_encoder`.
@@ -3023,6 +3026,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30233026 Parameters:
30243027 pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
30253028 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
3029+ adapter_name (`str`, *optional*):
3030+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3031+ `default_{i}` where i is the total number of adapters being loaded.
30263032 kwargs (`dict`, *optional*):
30273033 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
30283034 """
@@ -3040,7 +3046,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30403046 if not is_correct_format :
30413047 raise ValueError ("Invalid LoRA checkpoint." )
30423048
3043- self .load_lora_into_unet (state_dict , network_alphas = network_alphas , unet = self .unet , _pipeline = self )
3049+ self .load_lora_into_unet (
3050+ state_dict , network_alphas = network_alphas , unet = self .unet , adapter_name = adapter_name , _pipeline = self
3051+ )
30443052 text_encoder_state_dict = {k : v for k , v in state_dict .items () if "text_encoder." in k }
30453053 if len (text_encoder_state_dict ) > 0 :
30463054 self .load_lora_into_text_encoder (
@@ -3049,6 +3057,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30493057 text_encoder = self .text_encoder ,
30503058 prefix = "text_encoder" ,
30513059 lora_scale = self .lora_scale ,
3060+ adapter_name = adapter_name ,
30523061 _pipeline = self ,
30533062 )
30543063
@@ -3060,6 +3069,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
30603069 text_encoder = self .text_encoder_2 ,
30613070 prefix = "text_encoder_2" ,
30623071 lora_scale = self .lora_scale ,
3072+ adapter_name = adapter_name ,
30633073 _pipeline = self ,
30643074 )
30653075
0 commit comments