Remove unnecessary slicing in sdpa_attention_forward #41900
Conversation
The slicing in sdpa_attention_forward was there only because some masks were not constructed correctly (I was told). When the key size is dynamic, the slice op also prevents torch.export from correctly reasoning about that size. cc @vasqu
Signed-off-by: Justin Chu <[email protected]>
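For context, a minimal sketch of the pattern the PR removes, assuming the mask trimming in sdpa_attention_forward looked roughly like the snippet below (the function name, signature, and surrounding logic are illustrative, not the exact transformers source):

```python
# Minimal sketch of the slicing pattern removed by this PR. Names and signature
# are illustrative assumptions, not the exact transformers source.
import torch

def sdpa_attention_forward_sketch(query, key, value, attention_mask=None):
    if attention_mask is not None and attention_mask.ndim == 4:
        # Removed by this PR: trimming the mask to the key length. When the key
        # length is a dynamic dimension, this shape-dependent slice keeps
        # torch.export from reasoning cleanly about that size.
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
```

With the slice gone, the mask passed in is assumed to already match the key length, which is why the thread below asks for a broader slow-test run on older models.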
@Cyrilvallez Looks like this change passes the CI.
@vasqu @Cyrilvallez any thoughts? Thanks. This is an important fix we hope to include in the 5.0 release.
Responded in #41559 (comment). But I'm pro this; we might want to check some important models with a slow run. Let's wait for Cyril for a final decision.
Sorry for the delay, I was off as @vasqu mentioned! Still very relevant, and I would be very happy to finally remove this (and in the other attn functions as well, such as the eager ones, but I can take care of that myself later, no worries). cc @ydshieh, could you run a more extensive CI run on this PR and tell us whether you see any new failures, especially on older models? I don't have much time to do it manually myself as I need to catch up on all reviews 🤓 Just a bit scared that the fast tests may not be enough on this one!
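For reference, a minimal sketch of the analogous pattern in an eager attention path that the comment above alludes to; the names and exact structure are assumptions, not the actual transformers implementation:

```python
# Hypothetical sketch of the same slicing pattern in an eager attention path;
# names and signature are illustrative, not the exact transformers code.
import torch

def eager_attention_forward_sketch(query, key, value, attention_mask=None, scaling=1.0):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The slice below mirrors what sdpa_attention_forward used to do:
        # trim the mask to the key length before adding it to the scores.
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value)
```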
For sure, thank you for the ping. I will report back today or tomorrow.
run-slow: bert, gpt2, t5, modernbert, vit, clip, detr, table_transformer, got_ocr2, whisper, wav2vec2, qwen2_audio, speech_t5, csm, llama, gemma3, qwen2, mistral3, qwen2_5_vl, llava, smolvlm, internvl, gemma3n, gpt_oss, qwen2_5_omni
This comment contains models: ["models/bert", "models/clip", "models/csm", "models/detr", "models/gemma3", "models/gemma3n", "models/got_ocr2", "models/gpt2", "models/gpt_oss", "models/internvl", "models/llama", "models/llava", "models/mistral3", "models/modernbert", "models/qwen2", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_audio", "models/smolvlm", "models/t5", "models/table_transformer", "models/vit", "models/wav2vec2", "models/whisper"]
CI Results: ✅ No failing test specific to this PR 🎉!
[Update] Looks good!
Alright, amazing @ydshieh, thanks! Merging then! Thanks again @justinchuby for pushing on something we've wanted for a long time!