add support for simplefsdp+ep #1529
Conversation
Yes, it's turned on here:
The AC HOP doesn't allow mutations inside it. What are you wrapping in AC for this configuration? If you apply AC to attention and MoE experts separately, we can rewrite the buffer mutation to happen outside of the MoE.
I tried SAC in the default torchtitan wrapping setting. I'm working on a refactor of this PR. I can try wrapping AC separately later and hand it to you to check the mutation errors?
The log for the AC mutation error is: Note:
tianyu-l left a comment:
Thanks for the ground-breaking work on SimpleFSDP + EP.
Left comments.
    if job_config.compile.enable:
        torch._inductor.config.reorder_for_peak_memory = False
        torch._dynamo.config.capture_scalar_outputs = True
Last time I heard from @xmfan that there are correctness concerns for this field, although it gives us a full graph?
It is to ensure tolist() in A2A can be traced into the graph; otherwise it would cause graph breaks. It reads the items from the list by treating it as a tensor and reading each out with .item(). Not sure if there would be correctness concerns @xmfan.
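For context, a minimal standalone sketch of the pattern this flag unlocks (the function, shapes, and values are illustrative, not from the PR); without capture_scalar_outputs, the .item() call below graph-breaks under torch.compile:

    import torch

    torch._dynamo.config.capture_scalar_outputs = True  # trace .item()/.tolist() as unbacked SymInts

    def route(x: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
        n = counts[0].item()           # data-dependent scalar -> unbacked SymInt under compile
        torch._check(n >= 0)           # runtime asserts instead of data-dependent guard errors
        torch._check(n <= x.shape[0])
        return x[:n].sum()             # downstream compute depends on the unbacked value

    compiled = torch.compile(route, fullgraph=True)
    print(compiled(torch.randn(8, 4), torch.tensor([3, 5])))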
There shouldn't be. Bother @bobrenjc93 and @laithsakka if there are.
Flipping those flags should NOT "affect correctness", but let's define correctness here.
Flipping those flags opts into unbacked semantics for some unbacked-dependent ops (but that's the only way you can trace through those). With unbacked semantics some behaviors could deviate from eager, though not in a very harmful manner; the usual side effects are different output strides, or clones happening.
For example, a reshape that depends on unbacked symbols (outputs of .tolist()) could result in a clone, changing the output strides of the reshape; had those symbols been known and the reshape translated to a view, the strides could have been different.
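A toy illustration of that behavior (purely illustrative, not from the PR): the reshape below gets its target sizes from .tolist(), so under compile it may be lowered as a copy rather than a view, and output strides can then differ from eager:

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    def f(x: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
        a, b = sizes.tolist()                 # u0, u1: unbacked SymInts under compile
        torch._check(a >= 0)
        torch._check(b >= 0)
        torch._check(a * b == x.numel())      # tell the compiler the reshape is valid
        # With backed sizes this reshape could stay a view; with unbacked sizes the
        # compiler may materialize a clone, so strides can deviate from eager.
        return x.reshape(a, b)

    out = torch.compile(f, fullgraph=True)(torch.arange(12.0), torch.tensor([3, 4]))
    print(out.shape, out.stride())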
    transformer_block.moe.shared_experts = data_parallel(
        transformer_block.moe.shared_experts,
        dp_mesh,
        "replicate",
I'd recommend we keep the same decision as the torchtitan FSDP2 code -- shard instead of replicate. Using replicate here would mean it can't be bucketed with the rest of the TransformerBlock, as the comms are different.
The perf tradeoffs are not clear at this moment.
Do we know the perf ROI for replicate? Is there value in micro-benchmarking either approach?
The replicated compute in the shared expert can potentially be overlapped with A2A in token dispatch/combine. However, the shared expert is relatively small in DSV3. This is why Tianyu said replicating or sharding the shared experts won't make a big difference to perf.
    # we shouldn't parametrize the parameters that are already sharded in DP
    if isinstance(p, DTensor) and any(
        "dp_shard" in dim_name for dim_name in p._spec.mesh.mesh_dim_names
    ):
        enable_parametrizations.append(False)
    else:
        enable_parametrizations.append(True)
Sorry, I can't parse this. Is it for nested data_parallel wrapping?
Could you elaborate on the idea? I feel this is the biggest UX decision to make, so let's be careful.
Yes, this is for nested data_parallel. Specifically, it's for the MoE wrapping and model wrapping here: filtering out which modules have already been wrapped by data_parallel and wrapping only the rest would make deepseek_v3_parallelize.py messy.
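A self-contained illustration of the check being discussed, on a single-process gloo mesh (the mesh dim name and parameters are made up for the example):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # A 1-rank mesh whose dim name mimics torchtitan's dp_shard dimension.
    mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("dp_shard",))
    already_sharded = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])
    plain_param = torch.randn(4, 4)

    def needs_parametrization(p) -> bool:
        # Mirrors the quoted check: skip params already sharded on a dp_shard mesh dim.
        return not (
            isinstance(p, DTensor)
            and any("dp_shard" in name for name in p.device_mesh.mesh_dim_names)
        )

    print(needs_parametrization(already_sharded))  # False -> already wrapped by data_parallel, skip
    print(needs_parametrization(plain_param))      # True  -> still needs wrapping
    dist.destroy_process_group()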
        param_sharding = (Replicate(),) * device_mesh.ndim
    elif mode == "fully_shard":
-       param_sharding = (Shard(0),)
+       param_sharding = (Shard(0),) * device_mesh.ndim
    elif mode == "hybrid_shard":
        # replicate inter-host, fully shard intra-host
        param_sharding = (Replicate(), Shard(0))
We shouldn't always do DP shard on dim-0 for MoE. See #1561
Makes sense; I hit the dim-mismatch bug when adding HSDP. Will try to update it like this.
Here are the testing results following your PR. I also added _StridedShard based on offline discussion:
- 4 experts, FSDP2 EP4
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
- 8 experts, FSDP2 EP4
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
            inner_spec.placements[1],
        )
    else:
        # For FSDP + TP
Why does it have to be FSDP? Can it be FSDP + EP?
No, because this _distribute_dtensor function would only be called when p is a DTensor -- p has to be sharded by TP to be a DTensor. In the FSDP + EP case, it would call distribute_tensor instead of this one.
Why? After EP sharding it's also a DTensor.
Hmm, yes, you are right. It will call _distribute_dtensor. But the functionality is still the same, since it would be either FSDP+EP or FSDP+TP. We check whether "tp" is in the device mesh dim names when calling _StridedShard.
        grad_placements=self.grad_placements
    )

    if (
Again, I think here we should try to make the code generic, without hardcoding "ep" / "tp" if possible. Please give it a try and let me know if you find it impossible.
I totally agree. I think we can explicitly pass the dim names used for "ep" and "tp" into the data_parallel function?
Do you think having a dp_tp_ep_mesh_name arg input to allow users to specify their tp/ep mesh names would make sense? Link
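A hypothetical sketch of that suggestion (the non_dp_mesh_names parameter and its default are invented for illustration; the other parameters mirror the data_parallel call sites quoted in this thread):

    # Hypothetical signature only -- not the PR's actual API.
    def data_parallel(
        module,
        device_mesh,
        mode: str = "replicate",
        *,
        ac_mode: str = "none",
        mp_policy=None,
        non_dp_mesh_names: tuple[str, ...] = ("tp", "ep"),  # caller-specified TP/EP mesh dim names
    ):
        """Wrap `module` for SimpleFSDP; treat any mesh dim listed in `non_dp_mesh_names` as non-DP."""
        ...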
    transformer_block.moe.experts = data_parallel(
        transformer_block.moe.experts,
        world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
        dp_mode,
        ac_mode=job_config.activation_checkpoint.mode,
        mp_policy=mp_policy,
    )
We also need to solve this issue: #1551. Otherwise the gradient has the wrong scale.
Nice catch, will add!
@tianyu-l do you want to give the final stamp? I think it looks pretty good overall, but you raised some good points.
I was able to address most of @tianyu-l's comments, except for
tianyu-l left a comment:
As discussed offline, I would recommend:
- supporting nested data_parallel calls, but relying on the altered FSDP module name instead of the device mesh name, which is not reliable.
- supporting the gradient divisor via a custom subclass of Partial. This can be done in a separate PR.
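For the second bullet, a rough sketch of what a custom Partial subclass could look like (hedged: the class name, the factor field, and the pre-scaling approach are assumptions, not the actual follow-up implementation):

    from dataclasses import dataclass

    import torch
    from torch.distributed.tensor.placement_types import Partial

    @dataclass(frozen=True)
    class _ScaledPartial(Partial):
        """Partial placement that divides the local gradient by `factor` before reducing."""

        factor: float = 1.0

        def _reduce_value(self, tensor: torch.Tensor, mesh, mesh_dim: int) -> torch.Tensor:
            # Pre-divide, then let the base class run the usual reduction (all-reduce).
            return super()._reduce_value(tensor.div(self.factor), mesh, mesh_dim)

        def _reduce_shard_value(self, tensor: torch.Tensor, mesh, mesh_dim: int, shard_spec):
            # Same idea for the reduce-scatter path used by sharded gradients.
            return super()._reduce_shard_value(tensor.div(self.factor), mesh, mesh_dim, shard_spec)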
tianyu-l left a comment:
Almost there!
Let's:
- not change the name of this function
- include SimpleFSDP + EP composability tests for deepseek v3, e.g. you can consider these two: https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/models.py#L47-L79
Sounds good. Added new tests for FSDP+EP / FSDP+EP+TP / FSDP+EP+ETP.
I skipped PP since I found that both 1F1B and Interleaved1F1B don't compose well with EP+TP. The following PP enablements all work in eager mode; however, after adding compile, I hit some errors.
Here are some explorations (I will also add them to the PR summary). I didn't do in-depth debugging, but most of the failures seem to be related to dynamic shapes.
- SimpleFSDP + PP (1F1B) + TP + compile: works
- SimpleFSDP + PP (1F1B) + EP + compile: failed with error below (Tlparse Link)
[rank0]:[titan] 2025-09-29 21:51:23,835 - root - INFO - Training starts at step 1
[rank0]:/home/ruisizhang123/pytorch/torch/distributed/device_mesh.py:321: UserWarning: You are attempting to slice a submesh from another submesh. While we support this operation, it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. If not, this may result in some ranks receiving the submesh while others encounter errors.
[rank0]: warnings.warn(
[rank0]:/home/ruisizhang123/pytorch/torch/distributed/device_mesh.py:321: UserWarning: You are attempting to slice a submesh from another submesh. While we support this operation, it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. If not, this may result in some ranks receiving the submesh while others encounter errors.
[rank0]: warnings.warn(
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0]
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0]
[rank0]:/home/ruisizhang123/pytorch/torch/_inductor/lowering.py:2004: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:[rank0]: File "<frozen runpy>", line 88, in _run_code
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 671, in <module>
[rank0]:[rank0]: trainer.train()
[rank0]:[rank0]: ~~~~~~~~~~~~~^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/init.py", line 357, in wrapper
[rank0]:[rank0]: return f(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 596, in train
[rank0]:[rank0]: self.train_step(data_iterator)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 496, in train_step
[rank0]:[rank0]: loss = self.forward_backward_step(input_dict, labels)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 445, in forward_backward_step
[rank0]:[rank0]: self.pp_schedule.step(
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~^
[rank0]:[rank0]: inputs,
[rank0]:[rank0]: ^^^^^^^
[rank0]:[rank0]: ...<3 lines>...
[rank0]:[rank0]: input_batch=inputs,
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: )
[rank0]:[rank0]: ^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 624, in step
[rank0]:[rank0]: self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 839, in _step_microbatches
[rank0]:[rank0]: self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 583, in _initialize_stage
[rank0]:[rank0]: self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/stage.py", line 1510, in _prepare_forward_infra
[rank0]:[rank0]: outputs = self._shape_inference(args, kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/stage.py", line 1441, in _shape_inference
[rank0]:[rank0]: outputs = self.submod(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]:[rank0]: return super().__call__(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 363, in forward
[rank0]:[rank0]: def forward(
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1134, in forward
[rank0]:[rank0]: return compiled_fn(full_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 355, in runtime_wrapper
[rank0]:[rank0]: all_outs = call_func_at_runtime_with_args(
[rank0]:[rank0]: compiled_fn, args, disable_amp=disable_amp, steal_args=True
[rank0]:[rank0]: )
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/utils.py", line 130, in call_func_at_runtime_with_args
[rank0]:[rank0]: out = normalize_as_list(f(args))
[rank0]:[rank0]: ~^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 693, in inner_fn
[rank0]:[rank0]: unwrapped_outs = compiled_fn(unwrapped_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 727, in inner_fn
[rank0]:[rank0]: outs = compiled_fn(args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 529, in wrapper
[rank0]:[rank0]: return compiled_fn(runtime_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_inductor/output_code.py", line 617, in __call__
[rank0]:[rank0]: return self.current_callable(inputs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_inductor/utils.py", line 3018, in run
[rank0]:[rank0]: out = model(new_inputs)
[rank0]:[rank0]: File "/tmp/torchinductor_ruisizhang123/sk/csk3kend5qh55mkvyqnaqg7tj7zhubsaoiewgglmofcgbpn6me4y.py", line 2926, in call
[rank0]:[rank0]: raise RuntimeError('Eq(u0 + u1, 6144)')
[rank0]:[rank0]: RuntimeError: Eq(u0 + u1, 6144)
- SimpleFSDP + PP (Interleaved1F1B) + TP + compile: failed with error below (full log Link Tlparse (Link))
File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 571, in thunkify
[rank0]:[rank0]: r = f(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 1584, in _compute_proxy
[rank0]:[rank0]: n_args = tuple(
[rank0]:[rank0]: (
[rank0]:[rank0]: ...<4 lines>...
[rank0]:[rank0]: for a in args
[rank0]:[rank0]: )
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 1586, in
[rank0]:[rank0]: get_proxy_slot(a, self.tracer).force().node
[rank0]:[rank0]: ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 411, in get_proxy_slot
[rank0]:[rank0]: raise RuntimeError(
[rank0]:[rank0]: f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
[rank0]:[rank0]: )
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: RuntimeError: 256*s27 (140036273922896)is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7f5c29750fc0>
[rank0]:
[rank0]:[rank0]: While executing %embedding : [num_users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_tensor, %output, None, None, 2.0, False, False), kwargs = {})
[rank0]:[rank0]: Original traceback:
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 385, in forward
[rank0]:[rank0]: h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
[rank0]:
[rank0]:[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
[rank0]:
- SimpleFSDP + PP (Interleaved1F1B) + EP + compile: went through, but not sure if it's really working (Tlparse Link)
I checked the Tlparse for 4 and found it has the same Eq(u0 + u1, 6144) assert as 2 (SimpleFSDP + PP (1F1B) + EP + compile); it's just that the code didn't hit the Eq(u0 + u1, 6144) branch of the if-else.
Thank you for the review. Updated accordingly.
Looks very nice. Let's follow up with:
- Partial enhancement
- fixing PP
Please fix a minor comment issue (see inline) before merge.
        )

-       # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh
+       # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh/ep_mesh
This comment doesn't sound right, even before this PR. I think what it's trying to say is:
    # re-wrap all-gathered DTensor on dp_mesh to be on non_dp_mesh
This PR is a follow-up of the SimpleFSDP+EP PR (#1529). Here, we add a gradient_divide_factor following FSDP2 to ensure modules wrapped by FSDP+EP have the correct gradient reduction value.
- The original FSDP2 implementation is in PR #1551.
- The gradient_divide_factor logic is here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688

We have two ways of handling gradient_divide_factor in reduce_scatter:
1. Use ReduceOp.PREMUL_SUM to handle the gradient_divide_factor. However, DTensor's _reduce_shard_value only accepts reduce_op as a str input (https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210). To make _reduce_shard_value work correctly with ReduceOp.PREMUL_SUM, we would need to update DTensor's _reduce_shard_tensor and torch.distributed._functional_collectives.reduce_scatter_tensor so that the factor associated with ReduceOp.PREMUL_SUM can be passed as an input.
2. Simulate ReduceOp.PREMUL_SUM with ReduceOp.SUM. The logic is in this Diff: https://www.internalfb.com/diff/D76546536. It does a div_ over the gradient before performing ReduceOp.SUM.

Currently I'm following 2, since it requires fewer changes to _functional_collectives. After enabling reduction_divide_factor, we see FSDP(=2) + EP(=4) have identical loss (screenshot: https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de).
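A toy, single-process check of the equivalence option 2 relies on (no collectives involved; the per-rank gradients below are just simulated tensors):

    import torch

    # Option 2 in a nutshell: dividing each local gradient by the divide factor and then
    # summing (ReduceOp.SUM) matches what ReduceOp.PREMUL_SUM with scale 1/factor computes.
    factor = 4.0                                   # e.g. a data-parallel gradient divide factor
    grads = [torch.randn(8) for _ in range(4)]     # stand-ins for per-rank gradients

    premul_sum = sum(g * (1.0 / factor) for g in grads)   # PREMUL_SUM semantics
    div_then_sum = sum(g / factor for g in grads)         # pre-divide, then plain SUM

    torch.testing.assert_close(premul_sum, div_then_sum)
    print("max abs diff:", (premul_sum - div_then_sum).abs().max().item())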
As titled, this PR adds support for SimpleFSDP + EP.
Profiler Trace & Correctness
The following results are benchmarked on 8 H100 GPUs. Correctness is validated in eager mode (with the seed set to 42); I also added profile traces and tlparse in compile mode.
- Tlparse: Link
- Profile Trace: Link (there would be only FSDP comms)
- Tlparse: Link
- Profile Trace: Link
- Tlparse: Link
- Profile Trace: Link
- Tlparse: Link
- Profile Trace: Link
Status
Currently, SimpleFSDP+EP+AC doesn't work because of graph breaks caused by mutation inside the AC HOP (see the comment above).
@ezyang also points out that numerical issues in Float8 training may be related to this, since quantization calibration involves mutation, which might not compose well with AC.
There are some minor numerical differences after EP is added. @tianyu-l suggests that this is because the reduce-scatter is averaged over the whole dp_mesh, instead of dp_mod_ep_mesh, for the EP part. We will fix this in a separate PR.
PP (1F1B and Interleaved1F1B) + EP + SimpleFSDP has a series of composability issues: see #1529 (comment).
cc. @anijain2305