
Commit 6e37735: add support for simplefsdp+ep
Parent: 0f34257

11 files changed: +399 / -82 lines

.github/workflows/integration_test_8gpu_simple_fsdp.yaml

Lines changed: 1 addition & 1 deletion
@@ -47,4 +47,4 @@ jobs:
   python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
 
   mkdir artifacts-to-be-uploaded
-  python -m torchtitan.experiments.simple_fsdp.tests.integration_tests artifacts-to-be-uploaded --ngpu 8
+  python -m torchtitan.experiments.simple_fsdp.tests.llama3_integration_tests artifacts-to-be-uploaded --ngpu 8

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm"])
+_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"])
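
The frozenset above now gates dotted experiment names. As a rough illustration only (not torchtitan's actual loading code; `load_experiment_spec` is a hypothetical helper), a dotted name such as "simple_fsdp.deepseek_v3" maps directly onto a subpackage under torchtitan.experiments that exposes get_train_spec():

```python
import importlib

# Hypothetical sketch, not code from this commit: mirrors the frozenset in
# torchtitan/experiments/__init__.py and resolves a dotted experiment name
# to its subpackage, e.g. "simple_fsdp.deepseek_v3" ->
# torchtitan.experiments.simple_fsdp.deepseek_v3.get_train_spec().
_supported_experiments = frozenset(
    ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
)


def load_experiment_spec(name: str):
    if name not in _supported_experiments:
        raise ValueError(f"unsupported experiment: {name!r}")
    module = importlib.import_module(f"torchtitan.experiments.{name}")
    return module.get_train_spec()
```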

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 11 additions & 1 deletion
@@ -12,10 +12,18 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si
 
 ### Run SimpleFSDP Training on Llama 3
 
+#### Training Llama3 models
+
 ```bash
 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp --compile.enable
 ```
 
+#### Training DeepSeek_v3 models
+
+```bash
+CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name simple_fsdp --compile.enable
+```
+
 ### Composability Support
 
 Some of the features require the updates from PyTorch, with which we are working on providing composability support for the following features:
@@ -30,7 +38,9 @@ Some of the features require the updates from PyTorch, with which we are working
 |Pipeline Parallelism| ✅ |
 |Distributed Checkpointing| ✅ |
 |Float8 Training| 🚧 |
-
+|Expert Parallelism | ✅ |
+|Expert Parallelism + Activation Checkpointing| 🚧 |
+|Expert Parallelism + DualPipe| 🚧 |
 
 ### Citation

torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py (new file; the path was not captured on the page and is inferred from the llama3 counterpart below and this file's relative imports)

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.models.deepseek_v3 import deepseekv3_configs
+from torchtitan.models.llama3 import pipeline_llama
+from torchtitan.protocols.train_spec import TrainSpec
+from .deepseek_v3_model import SimpleFSDPDeepSeekV3Model
+from .deepseek_v3_parallelize import parallelize_deepseekv3
+
+
+def get_train_spec() -> TrainSpec:
+    return TrainSpec(
+        name="simple_fsdp.deepseek_v3",
+        model_cls=SimpleFSDPDeepSeekV3Model,
+        model_args=deepseekv3_configs,
+        parallelize_fn=parallelize_deepseekv3,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
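
The spec above wires the SimpleFSDP DeepSeek-V3 wrapper and its parallelize function into torchtitan's generic training components. A schematic sketch of how these fields feed into each other, under stated assumptions: the "debugmodel" flavor key, the helper name, and the import path are illustrative, and the real trainer loop plus the optimizer/dataloader builder signatures are not part of this diff.

```python
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims

# Path inferred from this commit's layout; adjust if the package lives elsewhere.
from torchtitan.experiments.simple_fsdp.deepseek_v3 import get_train_spec


def build_parallel_deepseek(
    parallel_dims: ParallelDims, job_config: JobConfig, flavor: str = "debugmodel"
):
    """Illustrative only: shows which TrainSpec fields are consumed at which stage."""
    spec = get_train_spec()
    # model_args maps flavor names to DeepSeekV3ModelArgs; "debugmodel" is an assumed key.
    model = spec.model_cls(spec.model_args[flavor])  # SimpleFSDPDeepSeekV3Model
    # parallelize_deepseekv3 applies TP/EP and SimpleFSDP's data_parallel wrapping.
    model = spec.parallelize_fn(model, parallel_dims, job_config)
    return model
```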
torchtitan/experiments/simple_fsdp/deepseek_v3/deepseek_v3_model.py (new file; path inferred from the import in the __init__.py above)

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs
+from ..simple_fsdp import enable_active_parametrization
+
+
+class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model):
+    def __init__(self, model_args: DeepSeekV3ModelArgs):
+        super().__init__(model_args)
+        self.init_weights()
+
+    def init_weights(self, *args, **kwargs):
+        with enable_active_parametrization():
+            super().init_weights(*args, **kwargs)
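
Both SimpleFSDP wrappers (this one and the llama3 one below) run `init_weights` under `enable_active_parametrization`, which is imported from the sibling `simple_fsdp` module and replaces the earlier `disable_data_parallel` helper. Its body is not part of this diff; purely as a shape sketch, and with the flag name and behavior being assumptions rather than simple_fsdp's actual code, such a helper is typically a context manager that toggles a module-level flag for the duration of the block:

```python
from contextlib import contextmanager

# Assumed shape only; the real enable_active_parametrization in the
# torchtitan simple_fsdp module is not shown in this commit.
_active_parametrization = False


@contextmanager
def enable_active_parametrization():
    global _active_parametrization
    prev = _active_parametrization
    _active_parametrization = True
    try:
        yield
    finally:
        _active_parametrization = prev
```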
torchtitan/experiments/simple_fsdp/deepseek_v3/deepseek_v3_parallelize.py (new file; path inferred from the import in the __init__.py above)

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
+
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
+from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
+from torchtitan.models.llama3.infra.parallelize import apply_ac
+from torchtitan.tools.logging import logger
+
+from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
+
+# Adapted from llama4/infra/parallelize.py
+def parallelize_deepseekv3(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    world_mesh = parallel_dims.world_mesh
+    # TODO: TP currently cannot handle uneven seq_len because we set
+    # `use_local_output=True` to use plain Tensors for legacy reasons.
+    # Need to revisit this.
+    assert (
+        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
+    ), f"""
+    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
+    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}.
+    """
+
+    if (
+        job_config.parallelism.context_parallel_degree > 1
+        and model.model_args.use_flex_attn
+    ):
+        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+
+    if parallel_dims.tp_enabled:
+        if job_config.parallelism.enable_async_tensor_parallel:
+            # TODO(jianiw): This branch needs to be tested and enabled
+            raise NotImplementedError(
+                "Currently, async TP is not tested for deepseekv3. \
+                torch.compile is not supported yet, which is required for async TP."
+            )
+
+        enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        if enable_float8_tensorwise_tp:
+            # TODO(jianiw): This branch needs to be tested and enabled
+            raise NotImplementedError(
+                "Currently, float8 tensorwise TP is not tested for deepseekv3"
+            )
+
+        apply_non_moe_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=not job_config.parallelism.disable_loss_parallel,
+            enable_float8_tensorwise_tp=False,
+        )
+        maybe_enable_async_tp(job_config, world_mesh["tp"])
+
+    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
+        apply_moe_ep_tp(
+            model,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.tp_enabled
+                and parallel_dims.ep_enabled
+                and parallel_dims.etp_enabled
+                else None
+            ),
+            etp_enabled=parallel_dims.etp_enabled,
+        )
+
+    if job_config.activation_checkpoint.mode != "none":
+        apply_ac(model, job_config.activation_checkpoint)
+
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+    )
+
+    # apply data parallel
+    dp_mesh: DeviceMesh | None = None
+    if (
+        parallel_dims.fsdp_enabled
+        or parallel_dims.ep_enabled
+        or parallel_dims.dp_replicate_enabled
+    ):
+        if parallel_dims.dp_replicate_enabled:
+            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+                dp_mode = "hybrid_shard"
+            else:
+                dp_mesh_dim_names = ("dp_replicate",)
+                dp_mode = "replicate"
+        else:
+            dp_mesh_dim_names = ("dp_shard_cp",)
+            dp_mode = "fully_shard"
+
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
+        dp_mod_ep_mesh_dim_names = []
+
+        if parallel_dims.ep_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+
+        # allow users to parse their ep/tp mesh_dim_names to simple_fsdp
+        dp_tp_ep_mesh_name = {"dp": "dp_shard", "ep": "", "tp": ""}
+        if parallel_dims.ep_enabled and not parallel_dims.tp_enabled:
+            dp_tp_ep_mesh_name["ep"] = "ep"
+        elif not parallel_dims.ep_enabled and parallel_dims.tp_enabled:
+            dp_tp_ep_mesh_name["tp"] = "tp"
+        elif parallel_dims.ep_enabled and parallel_dims.tp_enabled:
+            dp_tp_ep_mesh_name["ep"], dp_tp_ep_mesh_name["tp"] = "ep", "tp"
+        shared_expert_dp_shard_dim = 0
+
+        for _, transformer_block in model.layers.items():
+            if transformer_block.moe_enabled and parallel_dims.ep_enabled:
+                if (
+                    world_mesh[tuple(dp_mod_ep_mesh_dim_names)].size() * parallel_dims.ep
+                    > transformer_block.moe.experts.num_experts
+                ):
+                    shared_expert_dp_shard_dim = 1
+
+                transformer_block.moe.experts = data_parallel(
+                    transformer_block.moe.experts,
+                    world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
+                    dp_mode,
+                    ac_mode=job_config.activation_checkpoint.mode,
+                    mp_policy=mp_policy,
+                    dp_tp_ep_mesh_name=dp_tp_ep_mesh_name,
+                    dp_shard_dim=shared_expert_dp_shard_dim,
+                )
+                transformer_block.moe.shared_experts = data_parallel(
+                    transformer_block.moe.shared_experts,
+                    dp_mesh,
+                    dp_mode,
+                    ac_mode=job_config.activation_checkpoint.mode,
+                    mp_policy=mp_policy,
+                    dp_tp_ep_mesh_name=dp_tp_ep_mesh_name,
+                )
+                # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
+                # transformer_block.moe.experts.set_gradient_divide_factor(
+                #     parallel_dims.fsdp_gradient_divide_factor,
+                # )
+
+        model = data_parallel(
+            model,
+            dp_mesh,
+            dp_mode,
+            ac_mode=job_config.activation_checkpoint.mode,
+            mp_policy=mp_policy,
+            dp_tp_ep_mesh_name=dp_tp_ep_mesh_name,
+        )
+
+        logger.info(
+            "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
+        )
+
+    if job_config.compile.enable:
+        torch._inductor.config.reorder_for_peak_memory = False
+        torch._dynamo.config.capture_scalar_outputs = True
+        model = torch.compile(model, fullgraph=True)
+
+    return model
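
A detail worth noting in the MoE branch above: routed experts are wrapped with `data_parallel` over the `dp_mod_ep` mesh, and their shard dim flips from 0 to 1 once the dp_mod_ep mesh size times the EP degree exceeds the number of experts, presumably because the expert dimension (dim 0) can no longer absorb that many shards. A standalone restatement of that arithmetic, with a made-up helper name and example numbers:

```python
def pick_expert_shard_dim(dp_mod_ep_size: int, ep_degree: int, num_experts: int) -> int:
    """Mirrors the shared_expert_dp_shard_dim selection in parallelize_deepseekv3
    (illustrative helper, not part of the commit)."""
    return 1 if dp_mod_ep_size * ep_degree > num_experts else 0


# 8 experts, EP degree 4, dp_mod_ep mesh of size 4: 16 > 8, so shard dim 1.
assert pick_expert_shard_dim(4, 4, 8) == 1
# 64 experts, EP degree 8, dp_mod_ep mesh of size 2: 16 <= 64, keep dim 0.
assert pick_expert_shard_dim(2, 8, 64) == 0
```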

torchtitan/experiments/simple_fsdp/__init__.py renamed to torchtitan/experiments/simple_fsdp/llama3/__init__.py

Lines changed: 4 additions & 5 deletions
@@ -13,14 +13,13 @@
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.models.llama3 import llama3_configs, pipeline_llama
 from torchtitan.protocols.train_spec import TrainSpec
-
-from .model import SimpleFSDPTransformer
-from .parallelize import parallelize_llama
+from .llama3_model import SimpleFSDPTransformer
+from .llama3_parallelize import parallelize_llama
 
 
 def get_train_spec() -> TrainSpec:
-    return TrainSpec(
-        name="simple_fsdp",
+    return TrainSpec(
+        name="simple_fsdp.llama3",
         model_cls=SimpleFSDPTransformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,

torchtitan/experiments/simple_fsdp/model.py renamed to torchtitan/experiments/simple_fsdp/llama3/llama3_model.py

Lines changed: 2 additions & 2 deletions
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 from torchtitan.models.llama3 import Transformer, TransformerModelArgs
-from .simple_fsdp import disable_data_parallel
+from ..simple_fsdp import enable_active_parametrization
 
 
 class SimpleFSDPTransformer(Transformer):
     def __init__(self, model_args: TransformerModelArgs):
         super().__init__(model_args)
 
     def init_weights(self, *args, **kwargs):
-        with disable_data_parallel():
+        with enable_active_parametrization():
             super().init_weights(*args, **kwargs)

torchtitan/experiments/simple_fsdp/parallelize.py renamed to torchtitan/experiments/simple_fsdp/llama3/llama3_parallelize.py

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,7 @@
 from torchtitan.models.llama3.infra.parallelize import apply_tp
 from torchtitan.tools.logging import logger
 
-from .simple_fsdp import data_parallel, MixedPrecisionPolicy
+from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
 
 # for selective op activation checkpointing
@@ -116,7 +116,9 @@ def parallelize_llama(
             ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
         )
-        logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
+        logger.info(
+            "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
+        )
 
     if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
