 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return

+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
+
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }

     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
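To illustrate why the new code stores `module.old_forward` and builds each `new_forward` inside a factory function, here is a minimal, self-contained sketch of the same closure-based monkey-patch pattern on a single linear layer. The names (`TinyLoRALayer`, `patch_linear`, `unpatch_linear`) are illustrative stand-ins, not the diffusers API:

# Sketch of the patch/unpatch pattern used in the diff above (hypothetical helpers).
import torch
import torch.nn as nn


class TinyLoRALayer(nn.Module):
    """Low-rank residual added on top of a frozen projection."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as a no-op

    def forward(self, x):
        return self.up(self.down(x))


def patch_linear(module, lora_layer):
    # Save the original forward on the module so the patch can be undone later.
    old_forward = module.old_forward = module.forward

    # The factory locks in `old_forward` and `lora_layer` for this specific module,
    # avoiding Python's late-binding pitfall when patching several modules in a loop.
    def make_new_forward(old_forward, lora_layer):
        def new_forward(x):
            return old_forward(x) + lora_layer(x)

        return new_forward

    module.forward = make_new_forward(old_forward, lora_layer)


def unpatch_linear(module):
    # Restore the original forward and drop the marker attribute, as in the diff.
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        delattr(module, "old_forward")


if __name__ == "__main__":
    proj = nn.Linear(8, 8)
    lora = TinyLoRALayer(8, 8)
    x = torch.randn(2, 8)

    base = proj(x)
    patch_linear(proj, lora)
    patched = proj(x)  # base output + LoRA residual (zero-initialized, so identical here)
    unpatch_linear(proj)
    assert torch.allclose(base, proj(x))  # original behavior restored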