Change MFU calculation #274
Conversation
Uses Flash's FLOP counter of 7*: https://github.com/Dao-AILab/flash-attention/blob/23e8fa5a263d1c7122bc46a86ef32030ee7130f9/benchmarks/benchmark_flash_attention.py#L27

Excludes vocab embedding from FLOPS.

Did not test.
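For reference, the linked FlashAttention benchmark counts attention FLOPs roughly as follows (a paraphrased Python sketch, not a verbatim copy of that file): the forward pass is 4·batch·seqlen²·nheads·headdim FLOPs, halved under a causal mask, and the backward pass is counted as 2.5× the forward because of recomputation, so forward plus backward under a causal mask works out to the factor of 7 mentioned above.

```python
def attention_flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Sketch of the FlashAttention benchmark's attention FLOP count.

    Forward: two matmuls (QK^T and P@V), each seqlen**2 * headdim
    multiply-accumulates per head and batch, i.e. 4*batch*seqlen**2*nheads*headdim
    FLOPs in total, halved for a causal mask.
    Backward: counted as 2.5x forward (recomputation included), so
    forward + backward = 3.5x forward -> a factor of 7 with a causal mask.
    """
    assert mode in ("fwd", "bwd", "fwd_bwd")
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
```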
Hi @ad8e! Thank you for your pull request and welcome to our community.

**Action Required:** In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

**Process:** In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with `CLA signed`. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Probably I won't end up signing the CLA, for laziness reasons; my PyTorch CLA unfortunately doesn't apply here. Someone else can pick up these changes and submit them.
Thank you @ad8e for helping improve torchtitan! May I ask why we should exclude the vocab embedding from the FLOPS computation? It seems to me that the embedding layer is involved in both the forward and backward computations. Thanks!
Forward of embedding: it acts as a lookup table, so the FLOPs are 0 (or = hidden dimension, if you want to count the memory bandwidth).

EDIT: Actually, perhaps the vocab backward can also use the lookup-table method and skip the matmul, in which case the vocab layer wouldn't need any matmul in either forward or backward. I don't know the internal implementation of the embedding, though. In that case, my original PR has the correct formula.
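To illustrate the lookup-table point (a small illustrative sketch, not code from this PR): the embedding forward is an index gather, which is mathematically the same as multiplying a one-hot matrix by the embedding table, except the gather performs no multiply-accumulates at all.

```python
import torch
import torch.nn.functional as F

vocab, dim = 10, 4
emb = torch.nn.Embedding(vocab, dim)
ids = torch.tensor([3, 7])

# Forward is an index gather -- no multiply-accumulates are performed.
gathered = emb(ids)

# Mathematically identical to a one-hot matmul, which *would* cost
# 2 * vocab * dim FLOPs per token if it were actually executed.
one_hot = F.one_hot(ids, num_classes=vocab).to(emb.weight.dtype)
via_matmul = one_hot @ emb.weight

assert torch.allclose(gathered, via_matmul)
```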
https://discuss.pytorch.org/t/how-does-backward-work-for-embeddingbag/103342

Embedding backward uses the lookup table, as expected. So the vocab layer should be omitted entirely from FLOPS.
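This can be sanity-checked with PyTorch's built-in FLOP counter (assuming a recent PyTorch that ships torch.utils.flop_counter): neither the embedding forward (a gather) nor its backward (an index scatter-add into the weight gradient) registers any matmul FLOPs.

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

emb = torch.nn.Embedding(32000, 1024)
ids = torch.randint(0, 32000, (4, 512))

with FlopCounterMode(display=False) as counter:
    emb(ids).sum().backward()

# Expected: 0 -- no matmul FLOPs in either the forward or backward pass.
print(counter.get_total_flops())
```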
That makes sense! I've sent PR #280 to address this. The flash attention part is quite tricky (e.g. we shouldn't include the extra matmul recomputation in the backward pass in the MFU computation; do we really want to count the sparsity introduced by causal attention, given that hardware treats sparsity differently; etc.), so I'm keeping the factor of 12.
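For context, the resulting per-token FLOPs estimate is roughly the PaLM-style formula sketched below (illustrative names, not necessarily the exact code in #280): 6 FLOPs per non-embedding parameter per token for the dense forward + backward, plus 12·num_layers·num_heads·head_dim·seq_len for the attention score matmuls.

```python
def flops_per_token(num_params, num_embedding_params,
                    num_layers, num_heads, head_dim, seq_len):
    """PaLM-style per-token FLOPs estimate (illustrative sketch).

    Dense weights: 2 FLOPs/param forward + 4 FLOPs/param backward = 6 per token,
    with the embedding table excluded since it is a lookup, not a matmul.
    Attention scores: QK^T and P@V cost 4*seq_len*head_dim FLOPs per token per
    head forward, tripled for forward + backward -> the factor of 12.
    """
    dense = 6 * (num_params - num_embedding_params)
    attn = 12 * num_layers * num_heads * head_dim * seq_len
    return dense + attn

# MFU = (achieved tokens/sec) * flops_per_token(...) / peak hardware FLOPS
```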
tianyu's PR is better and has a CLA attached, so closing this in favor of the referenced PR. |