@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
 POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
 
 
+def determine_attention_kwargs_name(pipeline_class):
+    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+        if possible_attention_kwargs in call_signature_keys:
+            attention_kwargs_name = possible_attention_kwargs
+            break
+    assert attention_kwargs_name is not None
+    return attention_kwargs_name
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
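The new `determine_attention_kwargs_name` helper deduplicates the signature-introspection loop that previously appeared inline in several tests: it scans the pipeline's `__call__` parameters and returns whichever of the known attention-kwargs names that pipeline accepts. A minimal standalone sketch of the resolution logic, using a hypothetical `DummyPipeline` that is not part of this diff:

    import inspect

    POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]

    class DummyPipeline:
        def __call__(self, prompt, joint_attention_kwargs=None):
            ...

    # same idea as the helper: find the first known kwarg in the call signature
    keys = inspect.signature(DummyPipeline.__call__).parameters.keys()
    name = next((k for k in POSSIBLE_ATTENTION_KWARGS_NAMES if k in keys), None)
    print(name)  # "joint_attention_kwargs"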
@@ -442,14 +454,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
-        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
 
         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
 
         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
         output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         # unloading should remove the LoRA layers
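The two added assertions pin down the `num_fused_loras` bookkeeping around a fuse/unfuse round trip. Conceptually, fusing folds the LoRA delta into the base weight and unfusing subtracts it back out, so the weights (and hence the counter) should return to their starting state. A hedged numeric sketch of that invariant in plain torch, illustrative rather than the diffusers implementation:

    import torch

    torch.manual_seed(0)
    W = torch.randn(8, 8)                        # base weight
    A, B = torch.randn(4, 8), torch.randn(8, 4)  # LoRA factors
    scale = 1.0

    W_fused = W + scale * (B @ A)             # fuse: bake the delta into the weight
    W_restored = W_fused - scale * (B @ A)    # unfuse: subtract it back out

    assert torch.allclose(W, W_restored, atol=1e-6)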
@@ -1608,26 +1610,21 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                 )
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
             denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
             denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-
-            # Attach a second adapter
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
-            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
-
             self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
                 lora_loadable_components = self.pipeline_class._lora_loadable_modules
                 if "text_encoder_2" in lora_loadable_components:
                     pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                     self.assertTrue(
                         check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                     )
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
 
             # set them to multi-adapter inference mode
             pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
 
             # Fusing should still keep the LoRA layers so output should remain the same
             outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,16 +1645,87 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             )
 
             pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+                    )
+
             pipe.fuse_lora(
                 components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
             )
+            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
 
             # Fusing should still keep the LoRA layers
             output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
                 np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                 "Fused lora should not change the output",
             )
+            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+        for lora_scale in [1.0, 0.8]:
+            for scheduler_cls in self.scheduler_classes:
+                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+                pipe = self.pipeline_class(**components)
+                pipe = pipe.to(torch_device)
+                pipe.set_progress_bar_config(disable=None)
+                _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(output_no_lora.shape == self.output_shape)
+
+                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                    )
+
+                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+                if self.has_two_text_encoders or self.has_three_text_encoders:
+                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                    if "text_encoder_2" in lora_loadable_components:
+                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                        self.assertTrue(
+                            check_if_lora_correctly_set(pipe.text_encoder_2),
+                            "Lora not correctly set in text encoder 2",
+                        )
+
+                pipe.set_adapters(["adapter-1"])
+                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+                outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+                pipe.fuse_lora(
+                    components=self.pipeline_class._lora_loadable_modules,
+                    adapter_names=["adapter-1"],
+                    lora_scale=lora_scale,
+                )
+                self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+                outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+                self.assertTrue(
+                    np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+                    "Fused lora should not change the output",
+                )
+                self.assertFalse(
+                    np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+                    "LoRA should change the output",
+                )
 
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
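The new `test_lora_scale_kwargs_match_fusion` asserts that scaling a LoRA at call time via the attention kwargs (`{"scale": s}`) and baking the same scale into the weights via `fuse_lora(..., lora_scale=s)` produce matching outputs. The linear-algebra identity being exercised, in a minimal torch sketch (illustrative only, not the library code):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)                        # activations
    W = torch.randn(8, 8)                        # base weight
    A, B = torch.randn(4, 8), torch.randn(8, 4)  # LoRA factors
    s = 0.8

    y_runtime = x @ W.T + s * (x @ (B @ A).T)  # scale applied at call time
    y_fused = x @ (W + s * (B @ A)).T          # scale baked into the weight

    assert torch.allclose(y_runtime, y_fused, atol=1e-5)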
@@ -1838,12 +1907,7 @@ def test_logs_info_when_no_lora_keys_found(self):
 
     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
 
         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)