Conversation
@wwwjn wwwjn commented Jun 23, 2025

Context

  1. Introduced a basic DSV3-16B model training config
  2. Enabled FSDP/HSDP on DSV3-16B model training

Performance

Current profiler trace looks like this: the `to_copy` takes too long and needs to be optimized. The copy comes from the dtype conversion in `class MoE()`:

```python
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)
```

With FSDP only:
[Screenshot 2025-06-23 at 2 10 20 PM: profiler trace with FSDP only]
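For reference, here is a minimal standalone sketch of the upcast/multiply/downcast pattern behind those `to_copy` calls; the tensor names and shapes are illustrative assumptions, not the actual torchtitan `MoE` code.

```python
# Minimal sketch of the dtype round-trip that shows up as to_copy in the profiler.
# Shapes and names are illustrative only.
import torch

bs, slen, top_k, dim = 2, 128, 6, 2048
x = torch.randn(bs * slen, dim, dtype=torch.bfloat16)            # activations in bf16 under FSDP
routed_output = torch.randn(bs * slen * top_k, dim, dtype=torch.bfloat16)
top_scores = torch.rand(bs * slen * top_k, dtype=torch.float32)  # router scores kept in fp32

# Upcast to fp32 for the score * activation product, then cast back to the
# activation dtype; each .to(...) materializes a copy (the to_copy calls).
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)
```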

@facebook-github-bot added the CLA Signed label (this label is managed by the Meta Open Source bot) Jun 23, 2025
@wwwjn wwwjn requested review from H-Huang and tianyu-l June 23, 2025 21:12
tianyu-l pushed a commit that referenced this pull request Jun 24, 2025
As titled: to save some H100 resources and avoid long waits, only run integration tests when the PR's base branch is main.

No need to run H100 tests on PRs like #1330
Contributor

Need to use a more realistic config, but we can revisit later.

Member

@H-Huang H-Huang left a comment

LGTM

```python
):
    # TODO: Add support for parallelizing the model; this is a placeholder function for now
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)
```
Member

My understanding is that for SAC we count the number of matmuls occurring during the forward pass, then selectively save, say, every N matmuls.

MoE might affect this in two ways:

  1. matmul imbalances (the gating/routing computation is lightweight, while the expert matmuls are heavy)
  2. Not sure how this interacts with expert parallelism across multiple ranks.

I'm not sure if we cover this in Llama4, any ideas @tianyu-l? Anyway, if SAC isn't covered I don't think it's that high priority, but maybe just add a comment.
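For context on the save-every-N-matmuls idea, below is a minimal sketch of how such a per-op SAC policy can be expressed with `torch.utils.checkpoint.create_selective_checkpoint_contexts`. The constant `SAVE_EVERY_N_MATMULS`, the helper names, and the op list are illustrative assumptions, not the actual torchtitan policy.

```python
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

SAVE_EVERY_N_MATMULS = 2  # illustrative frequency, not a real config value


def _make_policy(counts):
    def policy(ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            counts["mm"] += 1
            # Save the output of every Nth matmul; recompute the rest in backward.
            if counts["mm"] % SAVE_EVERY_N_MATMULS == 0:
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy


def sac_context_fn():
    # Fresh counter per forward pass so "every Nth matmul" restarts each time.
    return create_selective_checkpoint_contexts(_make_policy(defaultdict(int)))


# Usage (non-reentrant checkpointing around a transformer block):
# out = checkpoint(block, hidden_states, use_reentrant=False, context_fn=sac_context_fn)
```

With a counter like this, the lightweight gating/router matmuls and the heavy expert matmuls increment the same count, which is exactly the imbalance concern raised above.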

Contributor

That's a great point I'd missed! Let me note this down and see how to resolve it. If we can identify the router/gating matmuls, we can just ignore them in AC.

SAC per layer should still be more or less useful.

Contributor Author

I only tested full AC, not SAC. If we agree we won't support SAC, I can add a comment.

```diff
  # shape (bs*slen*top_k, dim)
  routed_output = self.experts(routed_input, num_local_tokens_per_expert)
- routed_output = routed_output * top_scores.unsqueeze(-1)
+ routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+     x.dtype
+ )
```
Member

Just curious: how come this is needed?

Contributor

The router computation is in fp32, so `top_scores` is in fp32.
This step keeps the score × activation computation in high precision and then casts back.
Router precision in MoE seems critical for training stability.

Contributor Author

@wwwjn wwwjn Jun 24, 2025

After applying FSDP, the `routed_output` at line 309 is bf16 and `top_scores` is float32. If we don't explicitly convert dtypes, `routed_output = routed_output * top_scores` at line 310 will have dtype float32 (auto-promoted to the higher precision).

```python
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
```

In this line, `out` is bf16, as we applied FSDP. So I added this explicit dtype conversion, following Llama4.
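To make the dtype flow concrete, here is a tiny standalone repro of the promotion described above; the tensor names, shapes, and the `token_indices` construction are illustrative assumptions, not the actual model code.

```python
# Tiny repro of bf16 * fp32 promotion and the resulting scatter_add dtype mismatch.
import torch

out = torch.zeros(4, 8, dtype=torch.bfloat16)            # bf16 combine buffer under FSDP
routed_output = torch.randn(6, 8, dtype=torch.bfloat16)  # expert outputs in bf16
top_scores = torch.rand(6, dtype=torch.float32)          # router scores computed in fp32
token_indices = torch.randint(0, 4, (6, 1)).expand(-1, 8)

promoted = routed_output * top_scores.unsqueeze(-1)
print(promoted.dtype)  # torch.float32 -- bf16 * fp32 promotes to fp32

# scatter_add expects src to match out's dtype, so the explicit cast back to
# bf16 is needed before the combine step.
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(out.dtype)
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
```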

Member

Thanks for the explanations!!

@wwwjn wwwjn merged commit b74918a into deepseek-v3 Jun 24, 2025
5 checks passed
@tianyu-l tianyu-l deleted the dsv3-fsdp branch June 25, 2025 02:01