diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index db36dfc30332..e7b7aadba3c7 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -63,9 +63,6 @@ def sdpa_attention_forward(
     else:
         sdpa_kwargs = {"enable_gqa": True}
 
-    if attention_mask is not None and attention_mask.ndim == 4:
-        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
-
     # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
     is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
 
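For context, the removed guard trimmed the last dimension of a 4D attention mask down to the key sequence length, so the mask matched the shape `torch.nn.functional.scaled_dot_product_attention` expects. After this patch, `sdpa_attention_forward` no longer does that trimming, so a 4D mask passed in must already align with `key.shape[-2]`. Below is a minimal standalone sketch (not part of the patch; tensor shapes and the over-long mask are illustrative assumptions) of what the removed lines did:

```python
# Sketch of the removed mask-slicing guard, with assumed shapes.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Hypothetical mask allocated for a longer cache than the keys in use,
# e.g. a padded/static KV cache; True means "attend to this position".
attention_mask = torch.zeros(batch, 1, q_len, kv_len + 4, dtype=torch.bool)
attention_mask[..., :kv_len] = True

# The guard removed by this diff: slice the mask to the key length so
# SDPA's broadcast/shape check passes.
if attention_mask is not None and attention_mask.ndim == 4:
    attention_mask = attention_mask[:, :, :, : key.shape[-2]]

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 5, 16])
```

Without the slicing, passing the over-long mask in this sketch would raise a shape error inside SDPA; with the patch applied, callers are responsible for supplying a mask whose last dimension already matches the keys.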