
Commit f62b324

Pass LoRA rank to LoRALinearLayer (huggingface#2191)
1 parent 8da25e2 commit f62b324

File tree

1 file changed: 8 additions, 8 deletions


models/cross_attention.py

Lines changed: 8 additions & 8 deletions
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
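
For context, here is a minimal sketch (not the actual diffusers implementation) of what a rank-parameterized LoRA linear layer looks like; the class name `LoRALinearLayerSketch` and the initialization scheme are illustrative assumptions. It shows why `rank` has to reach the constructor: it sets the inner dimension of the low-rank down/up projection pair, and before this commit the processors above dropped the argument, so every layer silently fell back to the default rank.

```python
# Minimal sketch of a rank-parameterized LoRA linear layer (assumed structure,
# not the diffusers source): the update is a low-rank product up(down(x)),
# whose inner dimension is `rank`.
import torch
import torch.nn as nn


class LoRALinearLayerSketch(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        # Down-projection to `rank` dims, then up-projection back out.
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # Common LoRA initialization (assumption): small Gaussian on the
        # down weight, zeros on the up weight, so the layer starts as a no-op.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# Usage: with the fix, a processor created with rank=8 actually builds
# rank-8 projections instead of the hard-coded default of 4.
layer = LoRALinearLayerSketch(in_features=320, out_features=320, rank=8)
out = layer(torch.randn(2, 77, 320))
print(out.shape)  # torch.Size([2, 77, 320])
```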
