@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
296296 def __init__ (self , hidden_size , cross_attention_dim = None , rank = 4 ):
297297 super ().__init__ ()
298298
299- self .to_q_lora = LoRALinearLayer (hidden_size , hidden_size )
300- self .to_k_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size )
301- self .to_v_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size )
302- self .to_out_lora = LoRALinearLayer (hidden_size , hidden_size )
299+ self .to_q_lora = LoRALinearLayer (hidden_size , hidden_size , rank )
300+ self .to_k_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size , rank )
301+ self .to_v_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size , rank )
302+ self .to_out_lora = LoRALinearLayer (hidden_size , hidden_size , rank )
303303
304304 def __call__ (
305305 self , attn : CrossAttention , hidden_states , encoder_hidden_states = None , attention_mask = None , scale = 1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
408408 def __init__ (self , hidden_size , cross_attention_dim , rank = 4 ):
409409 super ().__init__ ()
410410
411- self .to_q_lora = LoRALinearLayer (hidden_size , hidden_size )
412- self .to_k_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size )
413- self .to_v_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size )
414- self .to_out_lora = LoRALinearLayer (hidden_size , hidden_size )
411+ self .to_q_lora = LoRALinearLayer (hidden_size , hidden_size , rank )
412+ self .to_k_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size , rank )
413+ self .to_v_lora = LoRALinearLayer (cross_attention_dim or hidden_size , hidden_size , rank )
414+ self .to_out_lora = LoRALinearLayer (hidden_size , hidden_size , rank )
415415
416416 def __call__ (
417417 self , attn : CrossAttention , hidden_states , encoder_hidden_states = None , attention_mask = None , scale = 1.0
0 commit comments