@@ -679,6 +679,55 @@ def _unfuse_lora_apply(self, module):
         if hasattr(module, "_unfuse_lora"):
             module._unfuse_lora()
 
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the unet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`List[float]`, *optional*):
+                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        weights = process_weights(adapter_names, weights)
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+
 
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
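For illustration, the three UNet-level helpers added above would typically be reached through a pipeline's `unet` attribute once adapters are loaded. A minimal sketch, assuming the PEFT backend is enabled; the checkpoint, LoRA repo ids, and adapter names are placeholders, not from this commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Placeholder LoRA repositories and adapter names, purely for illustration.
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")
pipe.load_lora_weights("some-user/subject-lora", adapter_name="subject")

# Blend the two adapters inside the UNet only.
pipe.unet.set_adapters(["style", "subject"], weights=[0.7, 0.3])

# Temporarily switch every LoRA layer in the UNet off, then back on.
pipe.unet.disable_lora()
pipe.unet.enable_lora()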
@@ -1448,7 +1497,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name="default"
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1468,7 +1517,8 @@ def load_lora_into_unet(
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
-                The name of the adapter to load the weights into. By default we use `"default"`
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
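To illustrate the documented default: when `adapter_name` is omitted, each loaded LoRA receives an auto-generated name, so repeated loads remain individually addressable. A short sketch continuing the pipeline from the earlier example, with the naming behavior assumed from the docstring above and placeholder repo ids:

pipe.load_lora_weights("some-user/style-lora")    # registered as "default_0"
pipe.load_lora_weights("some-user/subject-lora")  # registered as "default_1"
pipe.set_adapters(["default_0", "default_1"])     # activate both with weight 1.0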
@@ -1500,38 +1550,19 @@ def load_lora_into_unet(
 
             state_dict = convert_unet_state_dict_to_peft(state_dict)
 
-            target_modules = []
-            ranks = []
+            rank = {}
             for key in state_dict.keys():
-                # filter out the name
-                filtered_name = ".".join(key.split(".")[:-2])
-                target_modules.append(filtered_name)
                 if "lora_B" in key:
-                    rank = state_dict[key].shape[1]
-                    ranks.append(rank)
+                    rank[key] = state_dict[key].shape[1]
 
-            current_rank = ranks[0]
-            if not all(rank == current_rank for rank in ranks):
-                raise ValueError("Multi-rank not supported yet")
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+            lora_config = LoraConfig(**lora_config_kwargs)
 
-            if network_alphas is not None:
-                alphas = set(network_alphas.values())
-                if len(alphas) == 1:
-                    alpha = alphas.pop()
-                    # TODO: support multi-alpha
-                else:
-                    raise ValueError("Multi-alpha not supported yet")
-            else:
-                alpha = current_rank
-
-            lora_config = LoraConfig(
-                r=current_rank,
-                lora_alpha=alpha,
-                target_modules=target_modules,
-            )
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(unet)
 
             inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-
             incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
 
             if incompatible_keys is not None:
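The replaced block above is what lifts the old single-rank/single-alpha restriction: instead of asserting one global rank, the loader now records the rank of every module from its `lora_B` matrix and lets `get_peft_kwargs` turn that mapping into a `LoraConfig`. A standalone sketch of the idea, using toy tensors and made-up key names rather than the library's actual helper:

import torch
from collections import Counter

# Toy PEFT-format state dict with two modules trained at different ranks.
state_dict = {
    "down_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 320),
    "down_blocks.0.attn.to_q.lora_B.weight": torch.zeros(320, 4),
    "mid_block.attn.to_k.lora_A.weight": torch.zeros(8, 640),
    "mid_block.attn.to_k.lora_B.weight": torch.zeros(640, 8),
}

# Same collection step as in the diff: a module's rank is the number of
# columns of its lora_B matrix.
rank = {key: t.shape[1] for key, t in state_dict.items() if "lora_B" in key}

# A per-module rank_pattern (a field on peft.LoraConfig) can then carry the
# exceptions to the most common rank, so mixed-rank checkpoints load cleanly.
r = Counter(rank.values()).most_common(1)[0][0]
rank_pattern = {k.replace(".lora_B.weight", ""): v for k, v in rank.items() if v != r}
print(r, rank_pattern)  # e.g. 4 {'mid_block.attn.to_k': 8}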
@@ -1655,12 +1686,14 @@ def load_lora_into_text_encoder(
             if adapter_name is None:
                 adapter_name = get_adapter_name(text_encoder)
 
+
             # inject LoRA layers and load the state dict
             text_encoder.load_adapter(
                 adapter_name=adapter_name,
                 adapter_state_dict=text_encoder_lora_state_dict,
                 peft_config=lora_config,
             )
+
             # scale LoRA layers with `lora_scale`
             scale_lora_layers(text_encoder, weight=lora_scale)
 
@@ -2258,7 +2291,7 @@ def unfuse_text_encoder_lora(text_encoder):
 
         self.num_fused_loras -= 1
 
-    def set_adapter_for_text_encoder(
+    def set_adapters_for_text_encoder(
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional[PreTrainedModel] = None,
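The rename above brings the per-text-encoder method in line with the plural `set_adapters` naming. The pipeline-level code in the next hunk calls it positionally as `(adapter_names, text_encoder, weights)`, so a direct call would look roughly like the sketch below; this reuses the placeholder pipeline and adapter names from earlier and assumes the third positional argument is the per-adapter weight list:

# Activate both adapters on the primary text encoder with custom weights.
pipe.set_adapters_for_text_encoder(["style", "subject"], pipe.text_encoder, [1.0, 0.5])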
@@ -2336,60 +2369,44 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] =
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        weights: List[float] = None,
+        unet_weights: List[float] = None,
+        te_weights: List[float] = None,
+        te2_weights: List[float] = None,
     ):
-        """
-        Sets the adapter layers for the unet.
-
-        Args:
-            adapter_names (`List[str]` or `str`):
-                The names of the adapters to use.
-            weights (`List[float]`, *optional*):
-                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
-        """
-        if not self.use_peft_backend:
-            raise ValueError("PEFT backend is required for this method.")
-
-        def process_weights(adapter_names, weights):
-            if weights is None:
-                weights = [1.0] * len(adapter_names)
-            elif isinstance(weights, float):
-                weights = [weights]
-
-            if len(adapter_names) != len(weights):
-                raise ValueError(
-                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
-                )
-            return weights
-
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
-        weights = process_weights(adapter_names, weights)
+        # Handle the UNET
+        self.unet.set_adapters(adapter_names, unet_weights)
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_weights_and_activate_adapters(value, adapter_names, weights)
+        # Handle the Text Encoder
+        if hasattr(self, "text_encoder"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, te_weights)
+        if hasattr(self, "text_encoder_2"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, te2_weights)
 
     def disable_lora(self):
-        """
-        Disables the LoRA layers for the unet.
-        """
         if not self.use_peft_backend:
             raise ValueError("PEFT backend is required for this method.")
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=False)
+        # Disable unet adapters
+        self.unet.disable_lora()
+
+        # Disable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.disable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.disable_lora_for_text_encoder(self.text_encoder_2)
 
     def enable_lora(self):
-        """
-        Enables the LoRA layers for the unet.
-        """
         if not self.use_peft_backend:
             raise ValueError("PEFT backend is required for this method.")
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=True)
+        # Enable unet adapters
+        self.unet.enable_lora()
+
+        # Enable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.enable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.enable_lora_for_text_encoder(self.text_encoder_2)
 
 
 class FromSingleFileMixin:
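Taken together, the pipeline-level methods now simply fan out to the UNet and text-encoder helpers. A closing usage sketch under the same assumptions as above: PEFT backend enabled, placeholder adapter names, and `text_encoder_2` only present on SDXL-style pipelines:

pipe.set_adapters(
    ["style", "subject"],
    unet_weights=[0.8, 0.5],
    te_weights=[1.0, 0.0],   # text_encoder
    te2_weights=[1.0, 0.0],  # text_encoder_2, SDXL only
)

pipe.disable_lora()  # switches LoRA off in the UNet and both text encoders
pipe.enable_lora()   # switches it back on without reloading any weights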