From c28ddd881319959b7d5f62c13407830a69e0c728 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Fri, 1 Aug 2025 14:09:48 -0700 Subject: [PATCH] add support for simplefsdp+ep --- torchtitan/experiments/__init__.py | 4 +- torchtitan/experiments/simple_fsdp/README.md | 14 +- .../simple_fsdp/deepseek_v3/__init__.py | 34 ++++ .../simple_fsdp/deepseek_v3/model.py | 19 +++ .../simple_fsdp/deepseek_v3/parallelize.py | 158 ++++++++++++++++++ .../simple_fsdp/{ => llama3}/__init__.py | 2 +- .../simple_fsdp/{ => llama3}/model.py | 5 +- .../simple_fsdp/{ => llama3}/parallelize.py | 6 +- .../experiments/simple_fsdp/simple_fsdp.py | 95 +++++++---- .../simple_fsdp/tests/integration_tests.py | 74 ++++++-- 10 files changed, 349 insertions(+), 62 deletions(-) create mode 100644 torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py create mode 100644 torchtitan/experiments/simple_fsdp/deepseek_v3/model.py create mode 100644 torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py rename torchtitan/experiments/simple_fsdp/{ => llama3}/__init__.py (97%) rename torchtitan/experiments/simple_fsdp/{ => llama3}/model.py (83%) rename torchtitan/experiments/simple_fsdp/{ => llama3}/parallelize.py (96%) diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 74c7eaec9b..0008238039 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,4 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm"]) +_supported_experiments = frozenset( + ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"] +) diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index 05d7425293..746d8df875 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -12,8 +12,16 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si ### Run SimpleFSDP Training on Llama 3 +#### Training Llama3 models + +```bash +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable +``` + +#### Training DeepSeek_v3 models + ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp --compile.enable +CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable ``` ### Composability Support @@ -30,7 +38,9 @@ Some of the features require the updates from PyTorch, with which we are working |Pipeline Parallelism| ✅ | |Distributed Checkpointing| ✅ | |Float8 Training| 🚧 | - +|Expert Parallelism | ✅ | +|Expert Parallelism + Activation Checkpointing| 🚧 | +|Expert Parallelism + Pipeline Parallelism| 🚧 | ### Citation diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py new file mode 100644 index 0000000000..a9d2f34abe --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.deepseek_v3 import deepseekv3_configs +from torchtitan.models.llama3 import pipeline_llama +from torchtitan.protocols.train_spec import TrainSpec + +from .model import SimpleFSDPDeepSeekV3Model +from .parallelize import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + name="simple_fsdp.deepseek_v3", + model_cls=SimpleFSDPDeepSeekV3Model, + model_args=deepseekv3_configs, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/model.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/model.py new file mode 100644 index 0000000000..83c9fde561 --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/model.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs + +from ..simple_fsdp import disable_active_parametrization + + +class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model): + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__(model_args) + self.init_weights() + + def init_weights(self, *args, **kwargs): + with disable_active_parametrization(): + super().init_weights(*args, **kwargs) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py new file mode 100644 index 0000000000..7d71c8aeaf --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp +from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp +from torchtitan.models.llama3.infra.parallelize import apply_ac +from torchtitan.tools.logging import logger + +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy + +# Adapted from llama4/infra/parallelize.py +def parallelize_deepseekv3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. 
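+    # For example (hypothetical degrees, not taken from any config in this
+    # patch): with tp=2 and cp=2, seq_len_divisor = 2 * (2 * 2) = 8, so
+    # seq_len must be a multiple of 8.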
+ assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}. + """ + + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for deepseekv3" + ) + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled + else None + ), + etp_enabled=parallel_dims.etp_enabled, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + + # apply data parallel + dp_mesh: DeviceMesh | None = None + if ( + parallel_dims.fsdp_enabled + or parallel_dims.ep_enabled + or parallel_dims.dp_replicate_enabled + ): + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mode = "hybrid_shard" + else: + dp_mesh_dim_names = ("dp_replicate",) + dp_mode = "replicate" + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mode = "fully_shard" + + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + + for _, transformer_block in model.layers.items(): + if transformer_block.moe_enabled and parallel_dims.ep_enabled: + experts_shard_dim = 0 + assert dp_mod_ep_mesh is not None + assert hasattr(transformer_block, "moe") + if ( + dp_mod_ep_mesh.size() * parallel_dims.ep + > transformer_block.moe.experts.num_experts + ): + experts_shard_dim = 1 + + transformer_block.moe.experts = data_parallel( + transformer_block.moe.experts, + dp_mod_ep_mesh, + dp_mode, + ac_mode=job_config.activation_checkpoint.mode, + mp_policy=mp_policy, + shard_dim=experts_shard_dim, + ) + # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp + # transformer_block.moe.experts.set_gradient_divide_factor( + # parallel_dims.fsdp_gradient_divide_factor, + # ) + + 
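+    # Note: the expert modules wrapped by data_parallel above are
+    # re-parametrized in place and their classes are renamed to
+    # SimpleFSDP<...>, so the model-wide data_parallel call below detects
+    # and skips them, and expert weights are not double-sharded.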
model = data_parallel( + model, + dp_mesh, + dp_mode, + ac_mode=job_config.activation_checkpoint.mode, + mp_policy=mp_policy, + ) + + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) + + if job_config.compile.enable: + torch._inductor.config.reorder_for_peak_memory = False + torch._dynamo.config.capture_scalar_outputs = True + model = torch.compile(model, fullgraph=True) + + return model diff --git a/torchtitan/experiments/simple_fsdp/__init__.py b/torchtitan/experiments/simple_fsdp/llama3/__init__.py similarity index 97% rename from torchtitan/experiments/simple_fsdp/__init__.py rename to torchtitan/experiments/simple_fsdp/llama3/__init__.py index fbbaa867bf..86e5e180f7 100644 --- a/torchtitan/experiments/simple_fsdp/__init__.py +++ b/torchtitan/experiments/simple_fsdp/llama3/__init__.py @@ -20,7 +20,7 @@ def get_train_spec() -> TrainSpec: return TrainSpec( - name="simple_fsdp", + name="simple_fsdp.llama3", model_cls=SimpleFSDPTransformer, model_args=llama3_configs, parallelize_fn=parallelize_llama, diff --git a/torchtitan/experiments/simple_fsdp/model.py b/torchtitan/experiments/simple_fsdp/llama3/model.py similarity index 83% rename from torchtitan/experiments/simple_fsdp/model.py rename to torchtitan/experiments/simple_fsdp/llama3/model.py index f3edf76ec0..b0c11f9a44 100644 --- a/torchtitan/experiments/simple_fsdp/model.py +++ b/torchtitan/experiments/simple_fsdp/llama3/model.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. from torchtitan.models.llama3 import Transformer, TransformerModelArgs -from .simple_fsdp import disable_data_parallel + +from ..simple_fsdp import disable_active_parametrization class SimpleFSDPTransformer(Transformer): @@ -13,5 +14,5 @@ def __init__(self, model_args: TransformerModelArgs): super().__init__(model_args) def init_weights(self, *args, **kwargs): - with disable_data_parallel(): + with disable_active_parametrization(): super().init_weights(*args, **kwargs) diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py similarity index 96% rename from torchtitan/experiments/simple_fsdp/parallelize.py rename to torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 6ae9b4f6bf..1e3741c613 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -14,7 +14,7 @@ from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger -from .simple_fsdp import data_parallel, MixedPrecisionPolicy +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy # for selective op activation checkpointing @@ -116,7 +116,9 @@ def parallelize_llama( ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, ) - logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) if job_config.compile.enable and "model" in job_config.compile.components: torch._inductor.config.reorder_for_peak_memory = False diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 38074a2844..8cb2a44730 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -34,7 +34,7 @@ @contextmanager -def disable_data_parallel(): +def disable_active_parametrization(): global _active_parametrization try: 
_active_parametrization = False @@ -52,12 +52,12 @@ class MixedPrecisionPolicy: def _distribute_dtensor( tensor: DTensor, device_mesh: DeviceMesh, - placements: Sequence[Placement], + dp_placements: Sequence[Placement], ) -> DTensor: """ Below are experimental enhancements to distribute a DTensor. - This helps enable Simple FSDP + TP, in which - inner spec/mesh is TP spec/mesh + This helps enable Simple FSDP + TP/EP, in which + inner spec/mesh is TP/EP spec/mesh outer spec/mesh is FSDP/DDP/HSDP spec/mesh The logic follows https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261 @@ -78,40 +78,41 @@ def _distribute_dtensor( submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names spanned_mesh = outer_global_mesh[submesh_names] - if len(placements) == 1: - assert placements[0].is_replicate() or placements[0].is_shard() - if placements[0].is_shard(): - # For FSDP + TP dtensor placement - shard_dim = placements[0].dim + if len(dp_placements) == 1: + assert dp_placements[0].is_replicate() or dp_placements[0].is_shard() + if dp_placements[0].is_shard(): + # For FSDP + EP/TP/EP+TP + assert len(inner_spec.placements) == 2 or len(inner_spec.placements) == 1 + shard_dim = dp_placements[0].dim split_factor = inner_spec.num_shards_map[shard_dim] tensor_placement = ( ( _StridedShard(shard_dim, split_factor=split_factor) if split_factor > 1 - else placements[0] + else dp_placements[0] ), - inner_spec.placements[0], - ) + ) + inner_spec.placements else: - # For DDP + TP dtensor placement - tensor_placement = (placements[0], inner_spec.placements[0]) - elif len(placements) == 2: - assert placements[0].is_replicate() and placements[1].is_shard() - # For HSDP + TP dtensor placement - shard_dim = placements[1].dim + # For DDP + TP/EP + assert len(inner_spec.placements) == 1 + tensor_placement = (dp_placements[0], inner_spec.placements[0]) + elif len(dp_placements) == 2: + assert dp_placements[0].is_replicate() and dp_placements[1].is_shard() + # For HSDP + EP/TP/EP+TP + assert len(inner_spec.placements) == 2 or len(inner_spec.placements) == 1 + shard_dim = dp_placements[1].dim split_factor = inner_spec.num_shards_map[shard_dim] tensor_placement = ( - placements[0], + dp_placements[0], ( _StridedShard(shard_dim, split_factor=split_factor) if split_factor > 1 - else placements[1] + else dp_placements[1] ), - inner_spec.placements[0], - ) + ) + inner_spec.placements else: raise ValueError( - f"Unsupported placement {placements} for distributing DTensor {tensor}" + f"Unsupported placement {dp_placements} for distributing DTensor {tensor}" ) current_spec = DTensorSpec( @@ -121,7 +122,7 @@ def _distribute_dtensor( ) target_spec = DTensorSpec( mesh=outer_mesh, - placements=(placements[-1],), + placements=(dp_placements[-1],), tensor_meta=inner_spec.tensor_meta, ) result_tensor = redistribute_local_tensor( @@ -157,7 +158,7 @@ def _register_parametrization( for param_name in param_names } module_cls = type( - f"FSDP{module.__class__.__name__}", + f"SimpleFSDP{module.__class__.__name__}", (module.__class__,), param_name_to_property, ) @@ -202,17 +203,19 @@ def __init__( mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype + self.ep_mesh_name, self.tp_mesh_name = "ep", "tp" def replicate_compute(self, x): # data parallel runtime replicate parameters and do local compute # the gradients are partial tensors that needs to perform reduction # (i.e. 
DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both) - - # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim) - if x._spec.mesh.mesh_dim_names[-1] == "tp": - tp_placement = x._spec.placements[-1] - dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"] - + # support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim) + non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim + assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported" + if non_dp_mesh_dims > 0: + # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"] + # after DeviceMesh supports slicing a non-root mesh + dp_mesh = self.device_mesh # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather sharded_local_tensor = x.to_local() sharded_dtensor = DTensor.from_local( @@ -227,20 +230,31 @@ def replicate_compute(self, x): backward_dtype=self.reduce_dtype, ) - # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh + # re-wrap all-gathered DTensor on dp_mesh to be on non_dp_mesh # TODO: DTensor should support this mesh collapsing operation replicated_local_tensor = replicated_dtensor.to_local( grad_placements=self.grad_placements ) + + non_dp_placements = tuple(x._spec.placements[-non_dp_mesh_dims:]) + non_dp_mesh_dim_names = tuple( + x._spec.mesh.mesh_dim_names[-non_dp_mesh_dims:] + ) + non_dp_mesh = x._spec.mesh[non_dp_mesh_dim_names] + output = DTensor.from_local( - replicated_local_tensor, tp_mesh, (tp_placement,) + replicated_local_tensor, non_dp_mesh, non_dp_placements ) - else: + elif non_dp_mesh_dims == 0: output = x.redistribute( placements=self.compute_placements, forward_dtype=self.param_dtype, backward_dtype=self.reduce_dtype, ).to_local(grad_placements=self.grad_placements) + else: + raise AssertionError( + f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}" + ) return output @@ -249,7 +263,7 @@ def forward(self, x): # This should never be set to true during forward, only outside for model # inspection / debugging / initialization # model initialization can be done now through - # with disable_data_parallel(): + # with disable_active_parametrization(): # model.init_weights() if not _active_parametrization: return x @@ -271,14 +285,15 @@ def data_parallel( mode="replicate", ac_mode: str = "none", mp_policy: Optional[MixedPrecisionPolicy] = None, + shard_dim: int = 0, ): if mode == "replicate": param_sharding = (Replicate(),) elif mode == "fully_shard": - param_sharding = (Shard(0),) + param_sharding = (Shard(shard_dim),) elif mode == "hybrid_shard": # replicate inter-host, fully shard intra-host - param_sharding = (Replicate(), Shard(0)) + param_sharding = (Replicate(), Shard(shard_dim)) assert ( device_mesh.ndim == 2 ), "hybrid sharded data parallel requires 2D DeviceMesh" @@ -292,6 +307,11 @@ def data_parallel( for mod in modules: params_dict = dict(mod.named_parameters(recurse=False)) + # we shouldn't apply data parallel to the modules that are already + # sharded by data parallel + if "SimpleFSDP" in mod.__class__.__name__: + continue + for p_name, p in params_dict.items(): if p is not None and p.numel() > 0: distribute_tensor_func = ( @@ -303,6 +323,7 @@ def data_parallel( distribute_tensor_func(p, device_mesh, param_sharding) ), ) + # to be compatible with DCP, we use a customized _register_parametrization # instead of nn.utils.parametrize.register_parametrization here # nn.utils.parametrize.register_parametrization( diff --git 
a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index 33ccbd8903..aa7f40cdc9 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -21,7 +21,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", ], ], @@ -31,7 +31,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--activation_checkpoint.mode selective", "--activation_checkpoint.selective_ac_option op", @@ -43,7 +43,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--activation_checkpoint.mode full", ], @@ -54,7 +54,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.tensor_parallel_degree 2", ], @@ -79,12 +79,12 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", ], [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", "--training.steps 20", @@ -96,7 +96,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", "--parallelism.pipeline_parallel_degree 2", @@ -104,7 +104,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: "--parallelism.tensor_parallel_degree 2", ], [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--training.steps 20", "--checkpoint.enable", @@ -120,7 +120,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_shard_degree 1", "--parallelism.data_parallel_replicate_degree 4", @@ -133,7 +133,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_shard_degree 2", "--parallelism.data_parallel_replicate_degree 2", @@ -146,7 +146,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_shard_degree 2", "--parallelism.data_parallel_replicate_degree 2", @@ -160,7 +160,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_replicate_degree 2", "--parallelism.tensor_parallel_degree 2", @@ -173,7 +173,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_shard_degree 2", 
"--parallelism.data_parallel_replicate_degree 2", @@ -187,7 +187,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", @@ -201,7 +201,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", "--training.steps 10", @@ -209,7 +209,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be # excluded during loading to avoid errors caused by mismatched dp_degree. [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", @@ -218,7 +218,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: ], # load at [tp:4]. [ - "--model.name simple_fsdp", + "--model.name simple_fsdp.llama3", "--compile.enable", "--checkpoint.enable", "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", @@ -230,6 +230,46 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: "optional_checkpoint", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp.deepseek_v3", + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.expert_parallel_degree 2", + ], + ], + "FSDP+EP", + "fsdp+ep", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp.deepseek_v3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + ], + ], + "FSDP+TP+EP", + "fsdp+tp+ep", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp.deepseek_v3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 2", + "--parallelism.expert_tensor_parallel_degree 2", + ], + ], + "FSDP+TP+EP+ETP", + "fsdp+tp+ep+etp", + ngpu=4, + ), ] return integration_tests_flavors