
Conversation

@wconstab (Contributor) commented May 10, 2024

Stack from ghstack (oldest at bottom):

Unchanged: we precompute freqs_cis for max_seqlen, which is >> seqlen for any given
batch.

Changed: instead of slicing self.freqs_cis down to seqlen at the top-level
transformer based on the input token shape, we slice it down to seqlen
inside the transformer layer, after the activation has been re-expanded to the
full seqlen in cases where TP has sharded across seqlen.

In the PP case, stage 1's input may be seqlen/TP instead of seqlen, but
we do not generally know this, which makes it hard for stage 1 to slice
freqs_cis correctly. It is easy to do the slicing deeper inside, since
at that point the full seqlen is known unambiguously.

Note: the full self.freqs_cis is stored in memory either way, and what gets
passed into every layer is just a view, so this change should not be
material for memory usage or otherwise. A rough sketch of the idea follows.
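
Below is a minimal, hypothetical sketch of where the slice moves, using assumed
llama-style names (Transformer, TransformerBlock, freqs_cis); it is not the exact
torchtitan code.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Inside a layer, x has already been re-expanded to the full seqlen
        # (e.g. after TP gathers the sequence dimension), so the slice below
        # is unambiguous even when the top-level input was sharded.
        seqlen = x.shape[1]
        freqs_cis = freqs_cis[:seqlen]  # just a view of the precomputed table
        # ... apply rotary embeddings / attention / FFN using freqs_cis ...
        return x


class Transformer(nn.Module):
    def __init__(self, max_seqlen: int, dim: int, n_layers: int):
        super().__init__()
        # Precomputed once for max_seqlen, which is >> seqlen for any batch.
        # (Placeholder values here; the real table holds rotary frequencies.)
        self.register_buffer("freqs_cis", torch.randn(max_seqlen, dim))
        self.layers = nn.ModuleList(TransformerBlock() for _ in range(n_layers))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Previously: freqs_cis = self.freqs_cis[: h.shape[1]] here, which is
        # wrong when h arrives already sharded (seqlen/TP) on a PP stage.
        for layer in self.layers:
            h = layer(h, self.freqs_cis)  # pass the full table; layers slice it
        return h
```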

[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 10, 2024
ghstack-source-id: 20ef05e
Pull Request resolved: #321
@facebook-github-bot added the CLA Signed label May 10, 2024
@lessw2020 (Contributor) left a comment:

makes sense - lgtm!

@wanchaol (Collaborator) left a comment:

lgtm!

@wconstab merged commit 231ebc1 into gh/wconstab/13/base May 13, 2024
wconstab added a commit that referenced this pull request May 13, 2024
@wconstab deleted the gh/wconstab/13/head branch May 13, 2024 21:46
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert 0 <= 1 < ndim

A Collaborator commented on this line:

not from this PR: I wonder what the point of the 0 <= 1 part is 😃 .

@wconstab (Contributor, Author) replied:

lol. It's always good to check your assumptions
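
(Aside, a hedged note for readers of the thread: Python chains comparisons, so the `0 <= 1` half of that assert is vacuously true and the check effectively reduces to `ndim > 1`.)

```python
import torch

x = torch.zeros(2, 3)
ndim = x.ndim
# Chained comparison: `0 <= 1 < ndim` means `(0 <= 1) and (1 < ndim)`,
# and `0 <= 1` is always true, so only `ndim > 1` is actually checked.
assert 0 <= 1 < ndim
assert ndim > 1  # equivalent in effect
```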

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024