# add support for simplefsdp+ep #1529
## Changes from all commits
**New file (34 lines added)** — train spec registration:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.deepseek_v3 import deepseekv3_configs
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import TrainSpec

from .model import SimpleFSDPDeepSeekV3Model
from .parallelize import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        name="simple_fsdp.deepseek_v3",
        model_cls=SimpleFSDPDeepSeekV3Model,
        model_args=deepseekv3_configs,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
```
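For orientation, here is a minimal sketch of how the new spec could be exercised in isolation. `get_train_spec()` and the `TrainSpec` fields come from the diff above; the surrounding check is an illustrative assumption (presumably a job config selects this spec by its registered name `simple_fsdp.deepseek_v3` rather than calling the function directly).

```python
# Illustrative sanity check only; get_train_spec() and the TrainSpec fields come
# from the diff above, and this snippet assumes it runs in a scope where
# get_train_spec is importable (the exact module path is not shown in this view).
from torchtitan.protocols.train_spec import TrainSpec

spec = get_train_spec()
assert isinstance(spec, TrainSpec)
assert spec.name == "simple_fsdp.deepseek_v3"

# model_args maps flavor names to DeepSeekV3ModelArgs instances defined upstream
# in deepseekv3_configs, so just list whatever flavors are available.
print(sorted(spec.model_args.keys()))
```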
**New file (19 lines added)** — SimpleFSDP wrapper around `DeepSeekV3Model`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs

from ..simple_fsdp import disable_active_parametrization


class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model):
    def __init__(self, model_args: DeepSeekV3ModelArgs):
        super().__init__(model_args)
        self.init_weights()

    def init_weights(self, *args, **kwargs):
        with disable_active_parametrization():
            super().init_weights(*args, **kwargs)
```
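The `init_weights` override re-runs the stock DeepSeek-V3 initialization with SimpleFSDP's parametrization switched off, presumably so the init code writes to the underlying parameters rather than going through the parametrization. Below is a minimal sketch of that general pattern; the flag, context manager, and stand-in parametrization are illustrative assumptions built on `torch.nn.utils.parametrize`, not the actual `disable_active_parametrization` implementation from `..simple_fsdp`.

```python
# A self-contained sketch of the "disable parametrization during init" pattern.
# It uses torch.nn.utils.parametrize as a stand-in mechanism; simple_fsdp's real
# parametrization and context manager are not shown in this PR excerpt.
from contextlib import contextmanager

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

_ACTIVE = True  # module-level switch checked by the parametrization


class _ShardingStandIn(nn.Module):
    """Stand-in for a parametrization that would normally gather/compute."""

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        if not _ACTIVE:
            return w      # init path: hand back the raw parameter object itself
        return w + 0      # active path: returns a new tensor, like real compute


@contextmanager
def disable_active_parametrization_sketch():
    global _ACTIVE
    prev, _ACTIVE = _ACTIVE, False
    try:
        yield
    finally:
        _ACTIVE = prev


lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", _ShardingStandIn())

with disable_active_parametrization_sketch():
    # With the parametrization short-circuited, `lin.weight` is the underlying
    # parameter again, so stock init code writes it in place. Outside the context
    # manager, `lin.weight` would be a freshly computed tensor and the init would
    # not reach the stored parameter.
    nn.init.zeros_(lin.weight)

assert torch.all(lin.parametrizations.weight.original == 0)
```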
**New file (158 lines added)** — `parallelize_deepseekv3`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac
from torchtitan.tools.logging import logger

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}.
    """

    if (
        job_config.parallelism.context_parallel_degree > 1
        and model.model_args.use_flex_attn
    ):
        raise NotImplementedError("CP support for FlexAttention is still in progress.")

    if parallel_dims.tp_enabled:
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
        if enable_float8_tensorwise_tp:
            # TODO(jianiw): This branch needs to be tested and enabled
            raise NotImplementedError(
                "Currently, float8 tensorwise TP is not tested for deepseekv3"
            )

        apply_non_moe_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=False,
        )
        maybe_enable_async_tp(job_config, world_mesh["tp"])

    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
                world_mesh["ep", "tp"]
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled
                else None
            ),
            etp_enabled=parallel_dims.etp_enabled,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    mp_policy = MixedPrecisionPolicy(
        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
    )

    # apply data parallel
    dp_mesh: DeviceMesh | None = None
    if (
        parallel_dims.fsdp_enabled
        or parallel_dims.ep_enabled
        or parallel_dims.dp_replicate_enabled
    ):
        if parallel_dims.dp_replicate_enabled:
            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
                dp_mode = "hybrid_shard"
            else:
                dp_mesh_dim_names = ("dp_replicate",)
                dp_mode = "replicate"
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)
            dp_mode = "fully_shard"

        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
```
A review thread on the comment line above:

> **Contributor:** Are the MoE params actually "sharded" on `dp_replicate` (borrowing the replicate dim for further sharding), or are they replicated, as the name would suggest?
>
> **Contributor (Author):** I copied this line from […]. I would also want to learn more about this from @tianyu-l. My basic understanding (maybe wrong) is: it depends on the dp dim left after EP borrowing. If the left […]
```python
        dp_mod_ep_mesh_dim_names = []

        if parallel_dims.ep_enabled:
            if parallel_dims.dp_replicate_enabled:
                dp_mod_ep_mesh_dim_names.append("dp_replicate")
            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
            dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]

        for _, transformer_block in model.layers.items():
            if transformer_block.moe_enabled and parallel_dims.ep_enabled:
                experts_shard_dim = 0
                assert dp_mod_ep_mesh is not None
                assert hasattr(transformer_block, "moe")
                if (
                    dp_mod_ep_mesh.size() * parallel_dims.ep
                    > transformer_block.moe.experts.num_experts
                ):
                    experts_shard_dim = 1

                transformer_block.moe.experts = data_parallel(
                    transformer_block.moe.experts,
                    dp_mod_ep_mesh,
                    dp_mode,
                    ac_mode=job_config.activation_checkpoint.mode,
                    mp_policy=mp_policy,
                    shard_dim=experts_shard_dim,
                )
                # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
                # transformer_block.moe.experts.set_gradient_divide_factor(
                #     parallel_dims.fsdp_gradient_divide_factor,
                # )

        model = data_parallel(
            model,
            dp_mesh,
            dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )

        logger.info(
            "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
        )

    if job_config.compile.enable:
        torch._inductor.config.reorder_for_peak_memory = False
        torch._dynamo.config.capture_scalar_outputs = True
```
A review thread on the `capture_scalar_outputs` flag above:

> **Contributor:** Last time I heard from @xmfan that there are correctness concerns for this field, although it gives us a full graph?
>
> **Contributor (Author):** It is to ensure […]
>
> **Contributor:** There shouldn't be. Bother @bobrenjc93 and @laithsakka if there are.
>
> **Contributor:** So flipping those flags should NOT "affect correctness", but let's define correctness here. Flipping those flags opts into unbacked semantics for some unbacked-dependent ops (but that's the only way you can trace through those). With unbacked semantics, some behaviors could deviate from eager, though not in a very harmful manner; the usual side effects are different output strides, or clones happening. For example, a reshape that depends on unbacked symbols (outputs of `.tolist()`) could result in a clone, changing the output strides of the reshape; had those symbols been known and the reshape translated to a view, the strides could have been different.
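For context on what `capture_scalar_outputs` buys here: it lets Dynamo trace through `Tensor.item()` / `.tolist()` calls (which MoE token-routing bookkeeping tends to rely on) by turning the results into unbacked SymInts instead of graph-breaking, which is what makes `fullgraph=True` feasible. A minimal sketch, not taken from the PR and assuming a reasonably recent PyTorch:

```python
import torch

# Without this flag, the .item() call below would force a graph break (or fail
# under fullgraph=True); with it, the value becomes an unbacked SymInt.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def alloc_for_expert(counts: torch.Tensor, hidden: int) -> torch.Tensor:
    # Data-dependent scalar, e.g. the number of tokens routed to one expert.
    k = counts.sum().item()
    torch._check_is_size(k)  # tell the compiler k is a valid (non-negative) size
    return counts.new_zeros(k, hidden)


out = alloc_for_expert(torch.tensor([3, 2, 4]), hidden=8)
print(out.shape)  # torch.Size([9, 8])
```

The diff then concludes: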
```python
        model = torch.compile(model, fullgraph=True)

    return model
```
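As a follow-up on the expert-sharding heuristic in the hunk above: the code shards the grouped expert weights along dim 0 (presumably the expert dimension) by default and switches to dim 1 once the total number of data-parallel shards times the EP degree exceeds the expert count. A worked example with illustrative numbers (none of them from the PR):

```python
# Toy numbers only; none of these values come from the PR.
num_experts = 64       # experts per MoE layer (assumed)
ep = 8                 # expert-parallel degree (assumed)
dp_mod_ep_size = 16    # size of the dp_mod_ep mesh (assumed)

# Mirrors the branch in parallelize_deepseekv3: shard dim 0 by default,
# switch to dim 1 when the total number of shards would exceed num_experts.
experts_shard_dim = 1 if dp_mod_ep_size * ep > num_experts else 0
print(experts_shard_dim)  # -> 1, since 16 * 8 = 128 > 64
```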