Conversation

@ad8e ad8e commented Apr 26, 2024

@facebook-github-bot

Hi @ad8e!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!


ad8e commented Apr 26, 2024

Probably I won't end up signing the CLA, for laziness reasons; my PyTorch CLA unfortunately doesn't apply here. Someone else can pick up these changes and submit them.

@wanchaol wanchaol requested a review from tianyu-l April 26, 2024 00:11
@wanchaol
Collaborator

> Probably I won't end up signing the CLA, for laziness reasons; my PyTorch CLA unfortunately doesn't apply here. Someone else can pick up these changes and submit them.

@ad8e Thanks for the PR! Sounds good. @tianyu-l, could you take a look at this and help submit the changes?

@tianyu-l
Contributor

> Excludes vocab embedding from FLOPS.

Thank you @ad8e for helping improve torchtitan!

May I ask why we should exclude vocab embedding from FLOPS computation? It seems to me that the embedding layer is involved in both the forward and backward computations.

Thanks!


ad8e commented Apr 26, 2024

Backward of the embedding: you're right, the vocab layer must be counted there, at FLOPS = 2 × vocab embedding params. The backward pass must compute the gradient for the vocab weights, but not the gradient with respect to the inputs.

Forward of the embedding: it acts as a lookup table, so the FLOPS are 0 (or equal to the hidden dimension, if you want to count the memory bandwidth).

So I would instead suggest 6 * num_params - 4 * v * d + 7 * l * h * q * t as my corrected formula.

EDIT: Actually, perhaps the vocab backward can also use the lookup-table method and skip the matmul, in which case the vocab layer wouldn't need any matmul in either the forward or the backward. I don't know the internal implementation of the embedding, though. If so, my original PR has the correct formula.
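
For concreteness, here is a minimal sketch of the bookkeeping behind this formula, reading v as vocab size, d as model dimension, and l, h, q, t as layers, heads, head dim, and sequence length (my reading of the symbols above; the function and flag names are illustrative, not from the PR). It covers both the matmul-backward case (the formula above) and the lookup-backward case from the edit:

```python
def flop_per_token(num_params, v, d, l, h, q, t, embedding_backward_is_matmul=True):
    """Per-token training FLOPs: 6 per parameter (2 forward + 4 backward),
    plus the attention term proposed above, minus the embedding corrections."""
    flops = 6 * num_params + 7 * l * h * q * t
    # Embedding forward is a table lookup: drop its 2 * v * d forward FLOPs.
    flops -= 2 * v * d
    if embedding_backward_is_matmul:
        # Weight gradient is a matmul (2 * v * d kept), but no input gradient
        # is needed, so drop 2 of the usual 4 backward FLOPs per parameter.
        flops -= 2 * v * d
    else:
        # Backward is also a lookup/scatter-add: drop all 4 * v * d backward
        # FLOPs, i.e. the vocab layer is excluded entirely (the original PR).
        flops -= 4 * v * d
    return flops
```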


ad8e commented Apr 26, 2024

https://discuss.pytorch.org/t/how-does-backward-work-for-embeddingbag/103342

Embedding backward uses the lookup table, as expected. So the vocab layer should be omitted entirely from FLOPS.
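
As a quick sanity check (a minimal standalone snippet using the public PyTorch API, not code from this PR): the weight gradient of nn.Embedding is just the output gradient rows accumulated at the looked-up indices, i.e. an index-add rather than a matmul.

```python
import torch
import torch.nn as nn

# nn.Embedding's backward scatters grad_output rows into weight.grad at the
# looked-up indices (an index_add), so it performs no matmul FLOPs.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
tokens = torch.tensor([1, 3, 3])
emb(tokens).sum().backward()

# Reproduce the weight gradient by hand with an index_add of the upstream
# gradient (all ones here, since we summed the output).
expected = torch.zeros_like(emb.weight)
expected.index_add_(0, tokens, torch.ones(3, 4))
assert torch.allclose(emb.weight.grad, expected)
```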

@tianyu-l
Contributor

> https://discuss.pytorch.org/t/how-does-backward-work-for-embeddingbag/103342
>
> Embedding backward uses the lookup table, as expected. So the vocab layer should be omitted entirely from FLOPS.

That makes sense! I've sent PR #280 to address this. The flash attention part is trickier (e.g., the extra matmul recomputation in the backward pass shouldn't count toward MFU, and it's debatable whether to account for the sparsity introduced by causal attention, since hardware treats sparsity differently), so I'm keeping the factor of 12.
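
Roughly, the resulting per-token FLOP formula then looks like the sketch below (my paraphrase of the approach, not the exact code in #280), where the parameter count excludes the vocab embedding and the attention term keeps the factor of 12:

```python
def flop_per_token(num_params_excl_embedding, n_layers, n_heads, head_dim, seq_len):
    # 6 = 2 (forward) + 4 (backward) FLOPs per parameter per token.
    # 12 * l * h * q * t covers the attention score/value matmuls, without
    # counting flash-attention recomputation or causal-attention sparsity.
    return (6 * num_params_excl_embedding
            + 12 * n_layers * n_heads * head_dim * seq_len)
```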


ad8e commented Apr 27, 2024

tianyu's PR is better and has a CLA attached, so closing this in favor of the referenced PR.

@ad8e ad8e closed this Apr 27, 2024
@ad8e ad8e deleted the patch-1 branch April 27, 2024 01:04
@awgu awgu mentioned this pull request Jul 11, 2024