
Commit d15e157

finally fix glm4v accordingly

1 parent 9873b2f · commit d15e157
2 files changed: +22 additions, −55 deletions

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 22 additions & 10 deletions
```diff
@@ -296,23 +296,31 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         query_states, key_states, value_states = (
             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         )
-
-        cos, sin = position_embeddings
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
```
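Not part of the commit, just a minimal sketch (with made-up shapes) of the fallback path this hunk adds: when `position_embeddings` is not supplied, `cos` and `sin` are rebuilt from `rotary_pos_emb` by duplicating it along the last dimension.

```python
import torch

# Toy stand-in for `rotary_pos_emb`: one theta value per (position, half-dim) pair.
# Shapes are illustrative only; the real tensor comes from the vision rotary embedding module.
seq_len, half_dim = 4, 8
rotary_pos_emb = torch.randn(seq_len, half_dim)

# Same fallback as in the diff: duplicate the theta values so cos/sin cover the full head dim.
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)  # (seq_len, 2 * half_dim)
cos, sin = emb.cos(), emb.sin()

# When the caller does pass `position_embeddings`, it is expected to be this same (cos, sin) pair.
assert cos.shape == sin.shape == (seq_len, 2 * half_dim)
```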
```diff
@@ -323,13 +331,17 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            is_causal=self.is_causal,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
+            is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
```
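Again not from the commit itself, a small illustration of why the hand-built block-diagonal mask could be dropped: `cu_seqlens` plus the derived `max_seqlen` describe the same per-segment attention blocks that the old boolean mask spelled out explicitly, in the form varlen flash-attention kernels consume.

```python
import torch

# Hypothetical cumulative sequence lengths for three packed vision segments.
cu_seqlens = torch.tensor([0, 3, 7, 9])
seq_length = int(cu_seqlens[-1])

# What the new code forwards to the attention interface: the longest segment length.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # -> 4

# The block-diagonal boolean mask the old code built for the same segmentation,
# kept here only to show both formulations encode identical attention blocks.
attention_mask = torch.zeros([1, 1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

print(max_seqlen)            # 4
print(attention_mask[0, 0])  # three True blocks of sizes 3, 4 and 2 on the diagonal
```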

src/transformers/models/glm4v/modular_glm4v.py

Lines changed: 0 additions & 45 deletions
```diff
@@ -512,51 +512,6 @@ def __init__(self, config: Glm4vVisionConfig) -> None:
         self.attention_dropout = config.attention_dropout
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.is_causal = False
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        query_states, key_states, value_states = (
-            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        )
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
-
-        query_states = query_states.transpose(0, 1).unsqueeze(0)
-        key_states = key_states.transpose(0, 1).unsqueeze(0)
-        value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            is_causal=self.is_causal,
-            **kwargs,
-        )
-        attn_output = attn_output.squeeze(0)
-        attn_output = attn_output.reshape(seq_length, -1).contiguous()
-        attn_output = self.proj(attn_output)
-        return attn_output
 
 
 class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
```
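For context (hypothetical class names, not the real GLM-4V hierarchy): deleting the duplicated `forward` from the modular file means the attention class relies on the implementation it inherits instead of carrying its own copy, which is plain Python method resolution.

```python
import torch
from torch import nn

class ParentVisionAttention(nn.Module):
    # Stand-in for the shared implementation; the names here are made up for illustration.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states * 2  # placeholder for the real attention math

class ChildVisionAttention(ParentVisionAttention):
    # No `forward` override here, so ParentVisionAttention.forward is used.
    pass

x = torch.ones(3)
assert torch.equal(ChildVisionAttention()(x), ParentVisionAttention()(x))
```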
