1 change: 1 addition & 0 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -11,6 +11,7 @@
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
35 changes: 31 additions & 4 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -4,13 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
from torchtitan.tools.logging import logger


def parallelize_deepseekv3(
@@ -19,5 +19,32 @@ def parallelize_deepseekv3(
parallel_dims: ParallelDims,
job_config: JobConfig,
):
# TODO: Add support for parallelizing the model, this is a placeholder function for now
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
Member:

My understanding is that for SAC we count the number of matmuls occurring during the forward pass, then selectively save every, say, N-th matmul.

MoE might affect this in two ways:

  1. Matmul imbalances (gating/routing computation is lightweight, while the expert matmuls are heavy).
  2. It's unclear how this interacts with expert parallelism across multiple ranks.

I'm not sure if we cover this in Llama4; any ideas, @tianyu-l? Anyway, if SAC isn't covered I don't think it's high priority, but maybe just add a comment.

Contributor:

That's a great point I'd missed! Let me note this down and see how to resolve it. If we can identify the router/gating matmuls, we can just ignore them in AC.

SAC per layer should still be more or less useful.

Contributor Author:

I only tested full AC, not SAC. If we agree not to support SAC, I can add a comment.


dp_mesh: DeviceMesh | None = None
if (
parallel_dims.dp_shard_enabled
): # apply FSDP or HSDP, potentially with Context Parallel
if parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard")
else:
dp_mesh_dim_names = ("dp_shard",)
dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
pp_enabled=parallel_dims.pp_enabled,
cpu_offload=job_config.training.enable_cpu_offload,
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
)

if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")

return model
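
Regarding the SAC discussion in the review thread above: below is a minimal, hypothetical sketch of a "save every N-th matmul" selective activation checkpointing policy, assuming PyTorch's `torch.utils.checkpoint.create_selective_checkpoint_contexts` API. It is not the torchtitan `apply_ac` implementation; the op set, counter handling, and helper names are illustrative only. Note that a policy like this counts the lightweight router/gating matmuls the same as the heavy expert matmuls, which is exactly the imbalance the thread points out.

```python
from collections import defaultdict
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops treated as "matmuls" for counting purposes (illustrative subset).
_MATMUL_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten.addmm.default,
    torch.ops.aten.bmm.default,
}


def _save_every_nth_matmul_policy(meta, n):
    # Count matmuls separately for the forward and recompute passes so both
    # passes make identical save/recompute decisions.
    def policy(ctx, op, *args, **kwargs):
        if op in _MATMUL_OPS:
            key = "recompute_mm" if ctx.is_recompute else "forward_mm"
            meta[key] += 1
            if meta[key] % n == 0:
                return CheckpointPolicy.MUST_SAVE  # keep this matmul's output
        return CheckpointPolicy.PREFER_RECOMPUTE  # recompute everything else

    return policy


def checkpoint_block(block, x, n=2):
    # Fresh counters for each checkpointed region.
    context_fn = partial(
        create_selective_checkpoint_contexts,
        _save_every_nth_matmul_policy(defaultdict(int), n),
    )
    return checkpoint(block, x, use_reentrant=False, context_fn=context_fn)


if __name__ == "__main__":
    block = torch.nn.Sequential(
        torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)
    )
    x = torch.randn(4, 16, requires_grad=True)
    checkpoint_block(block, x, n=2).sum().backward()
```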
1 change: 0 additions & 1 deletion torchtitan/models/deepseek_v3/model/args.py
@@ -111,7 +111,6 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
nparams_dense = 0

for name, p in model.named_parameters():
print(name)
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
13 changes: 8 additions & 5 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -217,6 +217,10 @@ def forward(
[k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
) # (bsz, seqlen, n_heads, qk_head_dim)

q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim)
k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim)
v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim)

# TODO: Need to pass softmax_scale to sdpa() interface.
# For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa
# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
@@ -310,11 +314,10 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
"freqs_cis", precompute_freqs_cis(model_args), persistent=False
)

self.layers = torch.nn.ModuleList()
self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers.append(
TransformerBlock(layer_id=layer_id, model_args=model_args)
)
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = nn.RMSNorm(model_args.dim)
self.output = nn.Linear(
model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
@@ -333,7 +336,7 @@ def forward(self, tokens: torch.Tensor):
"""
h = self.tok_embeddings(tokens)

for layer in self.layers:
for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = self.norm(h)
output = self.output(h) # (batch_size, seq_len, dim)
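
Regarding the TODO above about passing softmax_scale to sdpa(): a minimal sketch of calling scaled_dot_product_attention with an explicit scale and the default causal mask, assuming PyTorch >= 2.1 (where the `scale` argument is available). The 1/sqrt(qk_head_dim) value used here is only a stand-in; DeepSeek-V3's actual softmax scale may include additional corrections.

```python
import torch
import torch.nn.functional as F


def sdpa_with_scale(q, k, v, softmax_scale):
    # q, k: (bsz, n_heads, seqlen, qk_head_dim); v: (bsz, n_heads, seqlen, v_head_dim)
    # is_causal=True applies the default causal mask mentioned in the comment above;
    # `scale` overrides the default 1/sqrt(E) scaling applied to q @ k^T.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=softmax_scale)


# Toy shapes for illustration only.
bsz, n_heads, seqlen, qk_head_dim, v_head_dim = 2, 4, 16, 64, 32
q = torch.randn(bsz, n_heads, seqlen, qk_head_dim)
k = torch.randn(bsz, n_heads, seqlen, qk_head_dim)
v = torch.randn(bsz, n_heads, seqlen, v_head_dim)
out = sdpa_with_scale(q, k, v, softmax_scale=qk_head_dim**-0.5)  # (2, 4, 16, 32)
```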
4 changes: 3 additions & 1 deletion torchtitan/models/deepseek_v3/model/moe.py
@@ -307,7 +307,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# shape (bs*slen*top_k, dim)
routed_output = self.experts(routed_input, num_local_tokens_per_expert)
routed_output = routed_output * top_scores.unsqueeze(-1)
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
Member:

Just curious: how come this is needed?

Contributor:

Router computation is in fp32, so top_scores is in fp32.
This step keeps the score * activation computation in high precision and then casts back.
Router precision in MoE seems critical for training stability.

Contributor Author (@wwwjn, Jun 24, 2025):

After applying FSDP, routed_output at line 309 is bf16, while top_scores is float32. If we don't explicitly convert the dtype, routed_output = routed_output * top_scores at line 310 will have dtype float32 (auto-promoted to the higher precision).

out = out.scatter_add(dim=0, index=token_indices, src=routed_output)

In this line, out is bf16 because we applied FSDP, so I added this explicit dtype conversion, following llama4.

Member:

Thanks for the explanations!!

x.dtype
)

# shared expert
if self.shared_expert is not None:
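
A small, self-contained sketch of the dtype behavior discussed in the thread above (toy shapes and names, not the actual MoE code): multiplying a bf16 expert output by fp32 router scores auto-promotes the result to fp32, which then no longer matches the bf16 destination of scatter_add, hence the explicit upcast-then-downcast in the diff.

```python
import torch

num_tokens, dim = 8, 4

routed_output = torch.randn(num_tokens, dim, dtype=torch.bfloat16)  # expert output under FSDP mixed precision
top_scores = torch.rand(num_tokens, dtype=torch.float32)            # router scores computed in fp32

# A naive multiply auto-promotes to fp32 ...
naive = routed_output * top_scores.unsqueeze(-1)
print(naive.dtype)  # torch.float32

# ... which would no longer match the bf16 accumulator used by scatter_add.
out = torch.zeros(num_tokens, dim, dtype=torch.bfloat16)
token_indices = torch.arange(num_tokens).unsqueeze(-1).expand(-1, dim)
# out.scatter_add(dim=0, index=token_indices, src=naive)  # dtype mismatch would error here

# Instead, do the score * activation product in fp32 for accuracy, then cast back:
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(out.dtype)
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
print(out.dtype)  # torch.bfloat16
```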
67 changes: 67 additions & 0 deletions torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
Contributor:

Need to use a more realistic config, but we can revisit this later.

@@ -0,0 +1,67 @@
# torchtitan Config.toml

[job]
dump_folder = "./outputs"
description = "DeepSeek-V3 16B model training"
print_args = false

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "deepseek_v3"
flavor = "16B"
# test tokenizer.model, for debug purpose only
tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
local_batch_size = 32
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 10
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]