[DSV3] Adding 16B model training config, Enable FSDP and AC on DSV3-16B model #1330
Changes from all commits: f7e9ee9, f1bb6b8, ea262c4, 0b56b96, 5cbafad, 11d3b38, b98683a, dfe2b61
```diff
@@ -307,7 +307,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_local_tokens_per_expert)
-        routed_output = routed_output * top_scores.unsqueeze(-1)
+        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+            x.dtype
+        )
 
         # shared expert
         if self.shared_expert is not None:
```

Member: Just curious, how come this is needed?

Contributor: Router computation is in fp32, so

Contributor (Author): After applying FSDP, the
In this line, the

Member: Thanks for the explanations!!
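The replies above point at a dtype mismatch between the fp32 router scores and the expert outputs. As a rough illustration (hypothetical shapes, not the PR's code, assuming bf16 activations as in typical mixed-precision training): multiplying bf16 expert outputs by fp32 scores promotes the result to fp32, so the diff does the scaling in fp32 and explicitly casts back to the activation dtype (`x.dtype`).

```python
import torch

# Hypothetical shapes; illustrates the dtype promotion being discussed,
# not the actual torchtitan code.
routed_output = torch.randn(8, 16, dtype=torch.bfloat16)  # expert outputs (activation dtype)
top_scores = torch.rand(8, dtype=torch.float32)           # router scores computed in fp32

# A plain multiply promotes the result to fp32, so it no longer matches the
# bf16 activation dtype used by the rest of the forward pass.
promoted = routed_output * top_scores.unsqueeze(-1)
assert promoted.dtype == torch.float32

# The pattern in the diff: scale explicitly in fp32, then cast back to the
# activation dtype (x.dtype in the real code).
scaled = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(routed_output.dtype)
assert scaled.dtype == torch.bfloat16
```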
Contributor: Need to use a more realistic config, but we can revisit later.
New file (67 lines added):

```toml
# torchtitan Config.toml

[job]
dump_folder = "./outputs"
description = "DeepSeek-V3 16B model training"
print_args = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "deepseek_v3"
flavor = "16B"
# test tokenizer.model, for debugging purposes only
tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
local_batch_size = 32
seq_len = 2048
max_norm = 1.0  # grad norm clipping
steps = 10
compile = false
dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default"  # default / never / always

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 10
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "full"  # ["none", "selective", "full"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
```
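As a quick sanity check of the layout above, the file parses with the standard library's TOML reader. The filename below is hypothetical, since this view doesn't show the file's path in the repo.

```python
import tomllib  # standard library in Python 3.11+

# Hypothetical filename; the PR's actual path isn't shown in this view.
with open("deepseek_v3_16b.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["model"]["flavor"])                            # "16B"
print(cfg["activation_checkpoint"]["mode"])              # "full" (full AC, per the discussion below)
print(cfg["parallelism"]["data_parallel_shard_degree"])  # -1
```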
My understanding is that for SAC we count the number of matmuls occurring during the forward pass, then selectively save every Nth matmul, say.
MoE might affect this in two ways:
I'm not sure if we cover this in Llama4, any ideas @tianyu-l? Anyway, if SAC isn't covered I don't think it's that high priority, but maybe just add a comment.
That's a great point I'd missed! Let me note this down and see how to resolve it. If we can identify the router/gating matmuls, we can just ignore them in AC.
SAC per layer should still be more or less useful.
I only tested full AC, not SAC. If we agree we will not support SAC, I can add a comment.
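To make the matmul-counting concern concrete, here is a rough sketch of a "save every Nth matmul" SAC policy, assuming the `create_selective_checkpoint_contexts` / `CheckpointPolicy` API from recent PyTorch releases; it is an illustration under that assumption, not torchtitan's actual implementation. Because an MoE block adds router/gating and grouped-expert matmuls, a fixed count like this ends up saving a different set of activations than it would in a dense block, which is the point raised above.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative sketch only. Matmul-like ops to count; an MoE block adds
# router/gating (and grouped expert) matmuls on top of these, which shifts
# the count relative to a dense block.
_MATMUL_OPS = {torch.ops.aten.mm.default, torch.ops.aten.addmm.default}


def make_policy(save_every: int = 2):
    # Separate counters for the forward and recompute passes so both passes
    # see the same sequence of save/recompute decisions.
    counts = {"forward": 0, "recompute": 0}

    def policy_fn(ctx, op, *args, **kwargs):
        if op in _MATMUL_OPS:
            key = "recompute" if ctx.is_recompute else "forward"
            counts[key] += 1
            if counts[key] % save_every == 0:
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn


# Tiny stand-in for a transformer block: two matmuls per forward.
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, make_policy()),
)
out.sum().backward()
```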