run sdpa with dtensor #180
Conversation
This PR gets rid of the manual adjustment of the number of heads in the attention layers by using DTensor outputs of `wq`, `wk`, `wv`, so that SDPA is aware of the sharding.
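As an illustration of the mechanism (a minimal sketch, not the actual torchtitan code; the standalone `Attention` module, dimensions, and the 2-way mesh are assumptions), passing `use_local_output=False` keeps the `wq`/`wk`/`wv` outputs as DTensors, so the attention forward can keep the global head count and SDPA dispatches on sharded inputs:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Attention(nn.Module):
    """Minimal Llama-style attention used only to illustrate the plan below."""

    def __init__(self, dim: int = 4096, n_heads: int = 32):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        # wq/wk/wv return DTensors sharded on the head dimension, so the view
        # below uses the *global* n_heads -- no manual n_heads // tp_degree.
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        # SDPA receives DTensor inputs, so it is aware of the sharding.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)


# Assumes a 2-way TP process group (e.g. launched with torchrun --nproc_per_node=2).
tp_mesh = init_device_mesh("cuda", (2,))
block = parallelize_module(
    Attention().cuda(),
    tp_mesh,
    {
        # use_local_output=False keeps the outputs as DTensors instead of
        # converting them back to plain per-rank local tensors.
        "wq": ColwiseParallel(use_local_output=False),
        "wk": ColwiseParallel(use_local_output=False),
        "wv": ColwiseParallel(use_local_output=False),
        "wo": RowwiseParallel(),
    },
)
out = block(torch.randn(2, 16, 4096, device="cuda"))
```

Under 2-way TP each rank only materializes half of the heads, but the `view` still uses the global `n_heads` because DTensor tracks the sharding, which is exactly the manual adjustment this PR removes.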
| "attention.wq": col_parallel_strategy(), | ||
| "attention.wk": col_parallel_strategy(), | ||
| "attention.wv": col_parallel_strategy(), | ||
| "attention.wq": col_parallel_strategy(use_local_output=False), |
🤔 I thought we needed to replicate the freq_cis, but here it seems we don't need to?
Just curious, is this going to land soon, or does it have some risk or unfinished business? Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).
It hasn't landed because there is a very strange bug (#267) associated with (but seemingly not caused by) multiplication using DTensor. It would be triggered in the rotary embedding computation if this PR landed. I will work on the bug soon, since it will also benefit PP (IIUC). @wconstab
Oh, is this related to dispatching for complex numbers, by any chance?
@wconstab Possibly, we don't know. The …
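For context on the complex-number question: the rotary embedding step multiplies the query/key activations (DTensors once this PR is applied) by `freqs_cis` in complex form, roughly as in the sketch below (a generic Llama-style version under assumed shapes, not necessarily the exact code in this repo), which is why DTensor dispatch for complex dtypes comes into play:

```python
import torch


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # View the last dim as complex pairs: (bsz, seqlen, n_heads, head_dim)
    # -> complex tensor of shape (bsz, seqlen, n_heads, head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis of shape (seqlen, head_dim // 2) against the activations.
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])
    # These complex multiplications are where DTensor dispatch for complex
    # dtypes is exercised once xq/xk are DTensors.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```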
Force-pushed from 9d45a6c to e773b75.
Force-pushed from fe1f241 to a28e74e.
Stack from ghstack (oldest at bottom):
This PR gets rid of the manual adjustment of the number of heads in the attention layers by using DTensor outputs of `wq`, `wk`, `wv`, so that SDPA is aware of the sharding.