1 change: 1 addition & 0 deletions torchtitan/models/__init__.py
@@ -7,4 +7,5 @@

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.deepseek_v3 # noqa: F401
import torchtitan.models.llama3 # noqa: F401
54 changes: 54 additions & 0 deletions torchtitan/models/deepseek_v3/README.md
@@ -0,0 +1,54 @@
# DeepSeek-V3 in TorchTitan

DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model that uses Multi-head Latent Attention (MLA) in place of standard multi-head attention.
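
For orientation, here is a minimal sketch of the low-rank key/value compression at the core of MLA, using this model's argument names (`kv_lora_rank`, `qk_nope_head_dim`, `qk_rope_head_dim`, `v_head_dim`). It is illustrative only, not the implementation in `model/model.py`:

```python
import torch
from torch import nn

# Values below match the "16B" flavor defined in __init__.py.
dim, n_heads = 2048, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
qk_nope_head_dim, v_head_dim = 128, 128

# Down-project hidden states into one small shared KV latent per token ...
wkv_a = nn.Linear(dim, kv_lora_rank + qk_rope_head_dim, bias=False)
# ... then up-project the latent into per-head no-RoPE keys and values.
wkv_b = nn.Linear(kv_lora_rank, n_heads * (qk_nope_head_dim + v_head_dim), bias=False)

x = torch.randn(2, 8, dim)  # (batch, seq, dim)
latent = wkv_a(x)
kv_latent, k_rope = latent.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
kv = wkv_b(kv_latent).view(2, 8, n_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)
# Only the small (kv_lora_rank + qk_rope_head_dim)-dim latent needs to be cached,
# rather than full per-head keys and values.
```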

## Setup

### Download Tokenizer

```bash
# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
```
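
To sanity-check the download, the files can be loaded with the `tokenizers` library. The output path below is an assumption; adjust it to wherever the script actually places the files:

```python
# Quick check that the downloaded tokenizer round-trips text.
# NOTE: the path is an assumption -- point it at the script's output directory.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("assets/tokenizer/DeepSeek-V3/tokenizer.json")
ids = tok.encode("Hello, DeepSeek-V3!").ids
print(ids)
print(tok.decode(ids))
```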

## Training

### Debug Training

```bash
# Quick debug run with small model
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh
```
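
Individual options can typically be overridden on the command line, and the GPU count is controlled by the `NGPU` environment variable in `run_train.sh`. The override names below are assumptions; check the TOML config for the authoritative keys:

```bash
# Example: single-GPU debug run with a short step count and frequent logging
# (override names assumed from the debug_model.toml sections)
NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" \
  ./run_train.sh --training.steps 10 --metrics.log_freq 1
```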

### Full Model Training

```bash
# 16B parameter model: adapted from the DeepSeek-MoE 16B base model (https://huggingface.co/deepseek-ai/deepseek-moe-16b-base)
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh
```

```bash
# 671B parameter model
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml" ./run_train.sh
```


## Supported Features
- FSDP, HSDP
- Activation checkpointing
- Tensor Parallel (TP)
- Expert Parallel (EP)
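
These features are toggled through the training config TOML. The sketch below shows roughly where the knobs live; the section and key names are assumptions based on other torchtitan train configs, so treat the files under `train_configs/` as authoritative:

```toml
# Illustrative only -- section and key names assumed from other torchtitan configs.
[parallelism]
data_parallel_replicate_degree = 1   # HSDP replicate dimension
data_parallel_shard_degree = -1      # FSDP shard dimension (-1 = use remaining GPUs)
tensor_parallel_degree = 1           # TP
expert_parallel_degree = 1           # EP for the MoE layers

[activation_checkpoint]
mode = "selective"                   # "none", "selective", or "full"
selective_ac_option = "op"
```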


## To be added
- Modeling
    - Merge DeepSeek-V3 and Llama4 MoE common components
    - Attention Layer: need to pass softmax_scale to sdpa() to support scaling
- Parallelism
    - Context Parallel support for DeepSeek-V3
    - PP support for DeepSeek-V3
- torch.compile
- Quantization
- Testing
    - performance and loss convergence tests
    - CI integration
126 changes: 126 additions & 0 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
from .model.args import DeepSeekV3ModelArgs
from .model.model import DeepSeekV3Model

__all__ = [
    "parallelize_deepseekv3",
    "DeepSeekV3ModelArgs",
    "DeepSeekV3Model",
    "deepseekv3_configs",
]


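# Model flavors for this train spec. A flavor is selected in the training config
# TOML via the `flavor` field of the `[model]` section (e.g. flavor = "debugmodel").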
deepseekv3_configs = {
    "debugmodel": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=256,
        inter_dim=1024,
        moe_inter_dim=256,
        n_layers=3,
        n_dense_layers=1,
        n_heads=16,
        n_routed_experts=8,
        n_shared_experts=2,
        n_activated_experts=3,
        route_scale=1.0,
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    ),
    "16B": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=2048,
        inter_dim=10944,
        moe_inter_dim=1408,
        n_layers=27,
        n_dense_layers=1,
        n_heads=16,
        n_routed_experts=64,
        n_shared_experts=2,
        n_activated_experts=6,
        route_scale=1.0,
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    ),
    "236B": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=5120,
        inter_dim=12288,
        moe_inter_dim=1536,
        n_layers=60,
        n_dense_layers=1,
        n_heads=128,
        n_routed_experts=160,
        n_shared_experts=2,
        n_activated_experts=6,
        n_expert_groups=8,
        n_limited_groups=3,
        route_scale=16.0,
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
    ),
    "671B": DeepSeekV3ModelArgs(
        vocab_size=129280,
        dim=7168,
        inter_dim=18432,
        moe_inter_dim=2048,
        n_layers=61,
        n_dense_layers=3,
        n_heads=128,
        n_routed_experts=256,
        n_shared_experts=1,
        n_activated_experts=8,
        n_expert_groups=8,
        n_limited_groups=4,
        route_scale=2.5,
        score_func="sigmoid",
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        dtype="fp8",
    ),
}


register_train_spec(
    TrainSpec(
        name="deepseek_v3",
        cls=DeepSeekV3Model,
        config=deepseekv3_configs,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=None,
        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)