[DSV3] Adding 16B model training config, Enable FSDP and AC on DSV3-16B model #1330
Conversation
As titled: to save some H100 resources and avoid long waits, only run the integration test when the PR's base branch is main. No need to run H100 tests on PRs like #1330.
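For context, a sketch of how such gating could look in a GitHub Actions workflow; the workflow name and trigger below are illustrative, not the repo's actual CI file:

```yaml
# Illustrative only: restrict the H100 integration test to PRs targeting main.
name: integration_test_8gpu_h100  # hypothetical workflow name
on:
  pull_request:
    branches: [main]  # PRs whose base branch is not main will not trigger this
```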
Need to use a more realistic config, but we can revisit later.
H-Huang left a comment:
LGTM
```python
):
    # TODO: Add support for parallelizing the model; this is a placeholder function for now
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)
```
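For reference, a minimal sketch of what a full-AC `apply_ac` could do, assuming a torchtitan-style model where `model.layers` holds the transformer blocks (`apply_full_ac` is a made-up name, not the repo's actual implementation):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_full_ac(model: nn.Module) -> None:
    # Wrap each transformer block so its activations are recomputed in
    # backward instead of being kept alive through forward.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(
            layer_id, checkpoint_wrapper(block, preserve_rng_state=False)
        )
```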
My understanding is that for SAC we count the number of matmuls occurring during forward, then selectively save, say, every N-th matmul (see the counting-policy sketch below).
MoE might affect this in two ways:
- matmul imbalance: the gating/routing computation is lightweight, while the expert matmuls are heavy
- I'm not sure how this interacts with expert parallelism across multiple ranks

I'm not sure if we cover this in Llama4, any ideas @tianyu-l? Anyway, if SAC isn't covered I don't think it's that high priority, but maybe just add a comment.
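To make the counting concrete, here's a minimal sketch of an every-N-th-matmul SAC policy using PyTorch's selective-checkpoint API (`torch.utils.checkpoint`, 2.4+); `EveryNthMatmulPolicy` and `checkpoint_block` are made-up names, not torchtitan's actual policy:

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


class EveryNthMatmulPolicy:
    """Save the output of every N-th matmul seen; recompute everything else."""

    def __init__(self, n: int = 2):
        self.n = n
        self.counts = {"forward": 0, "recompute": 0}  # count per phase

    def __call__(self, ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            phase = "recompute" if ctx.is_recompute else "forward"
            self.counts[phase] += 1
            if self.counts[phase] % self.n == 0:
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE


def checkpoint_block(block_fn, *inputs, n: int = 2):
    # Fresh policy per call so the matmul counter restarts each forward.
    context_fn = partial(create_selective_checkpoint_contexts, EveryNthMatmulPolicy(n))
    return checkpoint(block_fn, *inputs, use_reentrant=False, context_fn=context_fn)
```

Note that the counter ticks for the tiny router matmul and the heavy expert matmuls alike, which is exactly the imbalance flagged above.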
That's a great point I'd missed! Let me note this down and see how to resolve it. If we can identify the router/gating matmuls, we can just ignore them in AC.
SAC per layer should still be more or less useful.
I only tested full AC, not SAC. If we agree we won't support SAC, I can add a comment.
```diff
  # shape (bs*slen*top_k, dim)
  routed_output = self.experts(routed_input, num_local_tokens_per_expert)
- routed_output = routed_output * top_scores.unsqueeze(-1)
+ routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+     x.dtype
+ )
```
Just curious: how come this is needed?
Router computation is in fp32, so `top_scores` is in fp32.
This step does the score × activation multiplication in high precision, then casts back.
Router precision in MoE seems critical for training stability.
After applying FSDP, `routed_output` at line 309 is bf16, while `top_scores` is float32. If we don't explicitly convert the dtype, `routed_output = routed_output * top_scores` at line 310 will have dtype float32 (auto-promoted to the higher precision). Then in

```python
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
```

`out` is bf16, since we applied FSDP, and `src` must match its dtype. So I added this explicit dtype conversion, following Llama4.
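A minimal standalone sketch of the promotion behavior described above, with made-up shapes:

```python
import torch

routed_output = torch.randn(8, 16, dtype=torch.bfloat16)  # expert output under FSDP mixed precision
top_scores = torch.rand(8, dtype=torch.float32)           # router scores stay in fp32

# bf16 * fp32 auto-promotes to fp32:
assert (routed_output * top_scores.unsqueeze(-1)).dtype == torch.float32

# scatter_add requires src to match out's dtype (bf16 here), so do the
# multiply in fp32 for stability, then cast back explicitly:
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
    torch.bfloat16
)

out = torch.zeros(4, 16, dtype=torch.bfloat16)
token_indices = torch.randint(0, 4, (8, 1)).expand(-1, 16)
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
assert out.dtype == torch.bfloat16
```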
Thanks for the explanations!!
…(#1331) As titled: to save some H100 resources and avoid long waits, only run the integration test when the PR's base branch is main. No need to run H100 tests on PRs like pytorch#1330.
…6B model (pytorch#1330)

## Context
1. Introduced a basic DSV3-16B model training config
2. Enabled FSDP/HSDP on DSV3-16B model training

## Performance
The current profiler trace looks like this: the `to_copy` takes too long and needs to be optimized. The copy comes from the dtype conversion in class `MoE()`:

```
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)
```

With FSDP only:

<img width="1544" alt="Screenshot 2025-06-23 at 2 10 20 PM" src="https://github.com/user-attachments/assets/bcd698dc-3899-46e0-ae53-e7f8b0db13fc" />
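For reference, a hedged sketch of how one might measure that cast with `torch.profiler` (toy shapes, CUDA assumed); the conversion shows up as `aten::_to_copy` in the op table:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Toy stand-in for the MoE score-scaling pattern above.
x = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)
scores = torch.rand(8192, 1, device="cuda", dtype=torch.float32)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    y = (x.to(torch.float32) * scores).to(torch.bfloat16)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```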