4 changes: 3 additions & 1 deletion torchtitan/experiments/__init__.py
@@ -4,4 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm"])
_supported_experiments = frozenset(
["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
)
14 changes: 12 additions & 2 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -12,8 +12,16 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si

### Run SimpleFSDP Training

#### Training Llama3 models

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable
```

#### Training DeepSeek_v3 models

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp --compile.enable
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable
```

### Composability Support
@@ -30,7 +38,9 @@ Some of the features require the updates from PyTorch, with which we are working
|Pipeline Parallelism| ✅ |
|Distributed Checkpointing| ✅ |
|Float8 Training| 🚧 |
|Expert Parallelism| ✅ |
|Expert Parallelism + Activation Checkpointing| 🚧 |
|Expert Parallelism + Pipeline Parallelism| 🚧 |

### Citation

34 changes: 34 additions & 0 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.deepseek_v3 import deepseekv3_configs
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import TrainSpec

from .model import SimpleFSDPDeepSeekV3Model
from .parallelize import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
return TrainSpec(
name="simple_fsdp.deepseek_v3",
model_cls=SimpleFSDPDeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
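
For orientation, here is a minimal sketch (not part of the PR) showing how this spec can be inspected; it assumes torchtitan is importable and that `deepseekv3_configs` is a flavor-name-to-model-args mapping, as with the Llama 3 configs. The `--model.name simple_fsdp.deepseek_v3` flag in the README selects this spec via the `name` field above.

```python
# Illustrative sketch only: inspect the TrainSpec exposed by this module.
# Assumes torchtitan is installed; model_args is assumed to be a {flavor: args} dict.
from torchtitan.experiments.simple_fsdp.deepseek_v3 import get_train_spec

spec = get_train_spec()
print(spec.name)                # "simple_fsdp.deepseek_v3", matched by --model.name
print(spec.model_cls.__name__)  # "SimpleFSDPDeepSeekV3Model"
print(list(spec.model_args))    # available model flavors (assumed to be dict keys)
```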
19 changes: 19 additions & 0 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/model.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs

from ..simple_fsdp import disable_active_parametrization


class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model):
def __init__(self, model_args: DeepSeekV3ModelArgs):
super().__init__(model_args)
self.init_weights()

def init_weights(self, *args, **kwargs):
with disable_active_parametrization():
super().init_weights(*args, **kwargs)
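
Overriding `init_weights` this way keeps weight initialization from routing through SimpleFSDP's active parametrizations. As a rough, hypothetical sketch of the general pattern only (not simple_fsdp's actual implementation), such a context manager can temporarily flip a flag that the parametrization consults:

```python
# Hypothetical sketch of a flag-guarded "disable" context manager; simple_fsdp's
# real disable_active_parametrization is implemented differently.
from contextlib import contextmanager

_PARAMETRIZATION_ACTIVE = True  # hypothetical switch a parametrization would check

@contextmanager
def disable_active_parametrization_sketch():
    global _PARAMETRIZATION_ACTIVE
    prev, _PARAMETRIZATION_ACTIVE = _PARAMETRIZATION_ACTIVE, False
    try:
        yield  # run init_weights() with parametrizations acting as pass-through
    finally:
        _PARAMETRIZATION_ACTIVE = prev  # restore even if initialization raises
```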
158 changes: 158 additions & 0 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac
from torchtitan.tools.logging import logger

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy

# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
model: nn.Module,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
world_mesh = parallel_dims.world_mesh
# TODO: TP currently cannot handle uneven seq_len because we set
# `use_local_output=True` to use plain Tensors for legacy reasons.
# Need to revisit this.
assert (
job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
), f"""
Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}.
"""

if (
job_config.parallelism.context_parallel_degree > 1
and model.model_args.use_flex_attn
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")

if parallel_dims.tp_enabled:
enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
"rowwise",
"rowwise_with_gw_hp",
)

enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
if enable_float8_tensorwise_tp:
# TODO(jianiw): This branch needs to be tested and enabled
raise NotImplementedError(
"Currently, float8 tensorwise TP is not tested for deepseekv3"
)

apply_non_moe_tp(
model,
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(
world_mesh["ep", "tp"]
if parallel_dims.tp_enabled
and parallel_dims.ep_enabled
and parallel_dims.etp_enabled
else None
),
etp_enabled=parallel_dims.etp_enabled,
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

# apply data parallel
dp_mesh: DeviceMesh | None = None
if (
parallel_dims.fsdp_enabled
or parallel_dims.ep_enabled
or parallel_dims.dp_replicate_enabled
):
if parallel_dims.dp_replicate_enabled:
if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
dp_mode = "hybrid_shard"
else:
dp_mesh_dim_names = ("dp_replicate",)
dp_mode = "replicate"
else:
dp_mesh_dim_names = ("dp_shard_cp",)
dp_mode = "fully_shard"

dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
# the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
Contributor:

are the MoE params actually 'sharded' on dp_replicate (borrowing the replicate dim for further sharding), or are they replicated as the name would suggest?

Contributor (Author):

I copied this line from the deepseek_v3 file.

I would also like to learn more about this from @tianyu-l. My basic understanding (maybe wrong) is that it depends on the dp dim that is left after EP borrows from it: if the remaining dp dim is sharded, the MoE parameters are sharded; if it is replicated, the parameters are replicated. But I'm curious: if you have EP=2 and HSDP (dp_shard=2, dp_replicate=2), does EP first borrow the dp_shard dim or the dp_replicate dim? @tianyu-l

dp_mod_ep_mesh_dim_names = []

if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]

for _, transformer_block in model.layers.items():
if transformer_block.moe_enabled and parallel_dims.ep_enabled:
experts_shard_dim = 0
assert dp_mod_ep_mesh is not None
assert hasattr(transformer_block, "moe")
if (
dp_mod_ep_mesh.size() * parallel_dims.ep
> transformer_block.moe.experts.num_experts
):
experts_shard_dim = 1

transformer_block.moe.experts = data_parallel(
transformer_block.moe.experts,
dp_mod_ep_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
shard_dim=experts_shard_dim,
)
# TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
# transformer_block.moe.experts.set_gradient_divide_factor(
# parallel_dims.fsdp_gradient_divide_factor,
# )

model = data_parallel(
model,
dp_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
)

logger.info(
"Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
)

if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
Contributor:

Last time I heard from @xmfan, there are correctness concerns with this flag, although it gives us a full graph?

Contributor (Author):

It is there to ensure that the tolist() in A2A can be traced into the graph; otherwise it would cause graph breaks. It reads the items out of the tensor one by one, each readout behaving like .item(). Not sure if there would be correctness concerns @xmfan.

Contributor:

There shouldn't be. Bother @bobrenjc93 and @laithsakka if there are

Comment:

So flipping those flags should NOT "affect correctness", but let's define correctness here.

Flipping those flags opts into unbacked semantics for some unbacked-dependent ops (it's the only way you can trace through them). With unbacked semantics, some behaviors could deviate from eager, though not in a very harmful manner; the usual side effects are different output strides, or clones happening.

For example, a reshape that depends on unbacked symbols (the outputs of .tolist()) could result in a clone, changing the output strides of the reshape; had those symbols been known and the reshape translated to a view, the strides could have been different.
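
As an aside, a minimal standalone example (hypothetical, not from the PR) of what `capture_scalar_outputs` changes: without it, the `.item()` call below graph-breaks and `fullgraph=True` errors out; with it, the scalar becomes an unbacked symbol and tracing stays in one graph.

```python
# Hypothetical example of torch._dynamo.config.capture_scalar_outputs.
import torch

torch._dynamo.config.capture_scalar_outputs = True  # let .item()/.tolist() be traced

@torch.compile(fullgraph=True)
def scale(x: torch.Tensor, count: torch.Tensor) -> torch.Tensor:
    n = count.item()  # captured as an unbacked SymInt instead of graph-breaking
    return x * n

print(scale(torch.randn(4), torch.tensor(3)))
```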

model = torch.compile(model, fullgraph=True)

return model
@@ -20,7 +20,7 @@

def get_train_spec() -> TrainSpec:
return TrainSpec(
name="simple_fsdp",
name="simple_fsdp.llama3",
model_cls=SimpleFSDPTransformer,
model_args=llama3_configs,
parallelize_fn=parallelize_llama,
@@ -5,13 +5,14 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama3 import Transformer, TransformerModelArgs
from .simple_fsdp import disable_data_parallel

from ..simple_fsdp import disable_active_parametrization


class SimpleFSDPTransformer(Transformer):
def __init__(self, model_args: TransformerModelArgs):
super().__init__(model_args)

def init_weights(self, *args, **kwargs):
with disable_data_parallel():
with disable_active_parametrization():
super().init_weights(*args, **kwargs)
@@ -14,7 +14,7 @@
from torchtitan.models.llama3.infra.parallelize import apply_tp
from torchtitan.tools.logging import logger

from .simple_fsdp import data_parallel, MixedPrecisionPolicy
from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


# for selective op activation checkpointing
@@ -116,7 +116,9 @@ def parallelize_llama(
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
)
logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
logger.info(
"Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
)

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False