Commit f40ef03

Remove unnecessary slicing in sdpa_attention_forward (#41900)
Remove redundant slicing in sdpa_attention_forward. The slicing in sdpa_attention_forward was there only because some masks were not constructed correctly (I was told). When the dimension is dynamic, the slice op also prevents torch.export from correctly reasoning about its size.

Signed-off-by: Justin Chu <[email protected]>
1 parent: 5150dac
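On the mask-construction point, a minimal sketch (shapes and variable names below are illustrative, not code from the repository): when a 4D mask is built with its last dimension already equal to the key sequence length, the slice removed by this commit is a no-op.

import torch
import torch.nn.functional as F

# Illustrative only: the removed slice does nothing for a correctly sized mask.
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 7, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# A correctly constructed 4D mask: last dim already matches key.shape[-2].
attention_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)
assert attention_mask[:, :, :, : key.shape[-2]].shape == attention_mask.shape

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 5, 16])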

File tree

1 file changed: +0 -3 lines changed

src/transformers/integrations/sdpa_attention.py

Lines changed: 0 additions & 3 deletions
@@ -63,9 +63,6 @@ def sdpa_attention_forward(
     else:
         sdpa_kwargs = {"enable_gqa": True}
 
-    if attention_mask is not None and attention_mask.ndim == 4:
-        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
-
     # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
     is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
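On the torch.export point, a hedged sketch of the removed pattern (the MaskSlice module, the Dim name, and the shapes are hypothetical, not from transformers): slicing the mask with a bound read off key.shape[-2] keeps a symbolic slice in the traced graph even when the mask already has the right length, which is the kind of dynamic-shape reasoning the commit message refers to.

import torch
from torch.export import Dim, export

# Hypothetical reproduction of the removed pattern, not code from transformers.
class MaskSlice(torch.nn.Module):
    def forward(self, attention_mask, key):
        # The line this commit removes: trim the mask to the key sequence length.
        return attention_mask[:, :, :, : key.shape[-2]]

mask = torch.ones(1, 1, 4, 8, dtype=torch.bool)  # (batch, 1, q_len, kv_len)
key = torch.randn(1, 2, 8, 16)                   # (batch, heads, kv_len, head_dim)

seq = Dim("seq", min=2)
ep = export(MaskSlice(), (mask, key), dynamic_shapes=({3: seq}, {2: seq}))

# The exported graph carries a slice whose end point is the symbolic key length,
# even though the output has the same shape as the input mask.
print(ep.graph)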
