@@ -570,15 +570,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)

         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -589,7 +589,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -722,17 +722,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -746,7 +746,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -782,7 +782,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query, out_dim=4)

         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -791,8 +791,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -809,7 +809,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -937,15 +937,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)

         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -958,7 +958,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -1015,15 +1015,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1043,7 +1043,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

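For context, a minimal sketch of the kind of LoRA-aware linear layer these processors assume: the projection's forward takes the LoRA weighting under the keyword `scale` (the spelling this commit standardizes on) rather than `lora_scale`. The class and parameter names below are illustrative assumptions, not the library's actual implementation.

# Illustrative sketch only; `LoRAAwareLinear` and `rank` are hypothetical names,
# not diffusers API. It shows why attn.to_q / to_k / to_v / to_out[0] can be
# called with a `scale` keyword: the layer adds a LoRA update weighted by it.
import torch.nn as nn

class LoRAAwareLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)

    def forward(self, hidden_states, scale=1.0):
        # base projection plus the low-rank LoRA update, weighted by `scale`
        return self.linear(hidden_states) + scale * self.lora_up(self.lora_down(hidden_states))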