     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
-    LoRALinearLayer,
     LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
     XFormersAttnProcessor,
 )
+from .models.lora import Conv2dWithLoRA, LinearWithLoRA, LoRAConv2dLayer, LoRALinearLayer
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -464,6 +464,36 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def _load_lora_aux(self, state_dict, network_alpha=None):
+        lora_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in lora_grouped_dict.items():
+            rank = value_dict["lora.down.weight"].shape[0]
+            hidden_size = value_dict["lora.up.weight"].shape[0]
+            target_modules = [module for name, module in self.named_modules() if name == key]
+            if len(target_modules) == 0:
+                logger.warning(f"Could not find module {key} in the model. Skipping.")
+                continue
+
+            target_module = target_modules[0]
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+
+            lora = None
+            if isinstance(target_module, Conv2dWithLoRA):
+                lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+            elif isinstance(target_module, LinearWithLoRA):
+                lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
+            else:
+                raise ValueError(f"Module {key} is not a Conv2dWithLoRA or LinearWithLoRA module.")
+            lora.load_state_dict(value_dict)
+            lora.to(device=self.device, dtype=self.dtype)
+
+            # install lora
+            target_module.lora_layer = lora
+
 
 class TextualInversionLoaderMixin:
     r"""
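For reference, a minimal sketch of the key grouping `_load_lora_aux` performs; the module path below is a made-up example, not taken from a real checkpoint:

    # The last three dot-separated pieces name the LoRA tensor ("lora.down.weight" /
    # "lora.up.weight"); everything before them is the path of the module to wrap.
    key = "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.lora.down.weight"
    module_path, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    assert module_path == "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj"
    assert sub_key == "lora.down.weight"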
@@ -825,10 +855,18 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs:
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux) = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs
+        )
+        self.load_lora_into_unet(
+            state_dict, network_alpha=network_alpha, unet=self.unet, state_dict_aux=unet_state_dict_aux
+        )
         self.load_lora_into_text_encoder(
-            state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+            state_dict,
+            network_alpha=network_alpha,
+            text_encoder=self.text_encoder,
+            lora_scale=self.lora_scale,
+            state_dict_aux=te_state_dict_aux,
         )
 
     @classmethod
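Rough usage sketch of the updated entry point; the model id and file name are placeholders, and it assumes a kohya-ss checkpoint that also carries ff/proj/mlp LoRA weights:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    # Attention LoRAs load as before; the UNet ff/proj_in/proj_out and text encoder
    # mlp LoRAs are now routed through the auxiliary state dicts as well.
    pipe.load_lora_weights(".", weight_name="my_kohya_lora.safetensors")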
@@ -962,13 +1000,14 @@ def lora_state_dict(
 
         # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         network_alpha = None
+        auxiliary_states = ({}, {})
         if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alpha, auxiliary_states = cls._convert_kohya_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alpha
+        return state_dict, network_alpha, auxiliary_states
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+    def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`
 
@@ -981,6 +1020,8 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            state_dict_aux (`dict`, *optional*):
+                A dictionary containing the auxiliary state dict (additional LoRA state) for the unet.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1005,8 +1046,11 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
+        if state_dict_aux:
+            unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
+
     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
 
@@ -1021,6 +1065,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
+            state_dict_aux (`dict`, *optional*):
+                A dictionary containing the auxiliary state dict (additional LoRA state) for the text encoder.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1078,6 +1124,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             ].shape[1]
 
             cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+            if state_dict_aux:
+                cls._load_lora_aux_for_text_encoder(text_encoder, state_dict_aux, network_alpha=network_alpha)
 
             # set correct dtype & device
             text_encoder_lora_state_dict = {
@@ -1109,6 +1157,37 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
             attn_module.v_proj = attn_module.v_proj.regular_linear_layer
             attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
+    @classmethod
+    def _load_lora_aux_for_text_encoder(cls, text_encoder, state_dict, network_alpha=None):
+        lora_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in lora_grouped_dict.items():
+            rank = value_dict["lora.down.weight"].shape[0]
+            target_modules = [module for name, module in text_encoder.named_modules() if name == key]
+            if len(target_modules) == 0:
+                logger.warning(f"Could not find module {key} in the model. Skipping.")
+                continue
+
+            target_module = target_modules[0]
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+            lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
+            lora_layer.load_state_dict(value_dict)
+            lora_layer.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            old_forward = target_module.forward
+
+            def make_new_forward(old_forward, lora_layer):
+                def new_forward(x):
+                    return old_forward(x) + lora_layer(x)
+
+                return new_forward
+
+            # Monkey-patch.
+            target_module.forward = make_new_forward(old_forward, lora_layer)
+
     @classmethod
     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         r"""
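A note on the `make_new_forward` factory above: defining `new_forward` directly in the loop would make every patched module call the `lora_layer` from the last iteration, because closures capture variables rather than values. A small self-contained sketch of that pitfall (plain Python, unrelated to the diffusers API):

    # All three lambdas close over the same loop variable and see its final value.
    fns = [lambda x: x + i for i in range(3)]
    print([f(0) for f in fns])  # [2, 2, 2]

    # A factory re-binds the value per iteration, which is what make_new_forward does.
    fns = [(lambda i: (lambda x: x + i))(i) for i in range(3)]
    print([f(0) for f in fns])  # [0, 1, 2]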
@@ -1225,6 +1304,8 @@ def save_function(weights, filename):
     def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {}
         te_state_dict = {}
+        unet_state_dict_aux = {}
+        te_state_dict_aux = {}
         network_alpha = None
 
         for key, value in state_dict.items():
@@ -1249,12 +1330,20 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                     diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                     diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
                     if "transformer_blocks" in diffusers_name:
                         if "attn1" in diffusers_name or "attn2" in diffusers_name:
                             diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                             diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                             unet_state_dict[diffusers_name] = value
                             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                        elif "ff" in diffusers_name:
+                            unet_state_dict_aux[diffusers_name] = value
+                            unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                        unet_state_dict_aux[diffusers_name] = value
+                        unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif lora_name.startswith("lora_te_"):
                     diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                     diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1266,11 +1355,14 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     if "self_attn" in diffusers_name:
                         te_state_dict[diffusers_name] = value
                         te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    elif "mlp" in diffusers_name:
+                        te_state_dict_aux[diffusers_name] = value
+                        te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
 
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha
+        return new_state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux)
 
     def unload_lora_weights(self):
         """
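Note for downstream callers: `lora_state_dict` and `_convert_kohya_lora_to_diffusers` now return a third element, so code that unpacked two values needs updating. A minimal sketch, with a placeholder checkpoint name:

    # The auxiliary dicts keep bare module paths (no "unet." / "text_encoder." prefix),
    # which is what lets _load_lora_aux resolve them via named_modules().
    state_dict, network_alpha, (unet_aux, te_aux) = pipe.lora_state_dict("my_kohya_lora.safetensors")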