Conversation

@ruisizhang123
Contributor

@ruisizhang123 ruisizhang123 commented Aug 5, 2025

As titled, this PR adds support for SimpleFSDP + EP.

Profiler Trace & Correctness

The following results are benchmarked on 8 H100 GPUs. Correctness is validated in eager mode (with the seed set to 42); I also added a profile trace and tlparse in compile mode.

  1. SimpleFSDP
  • Loss: [loss curve screenshot]
  • Tlparse: Link
  • Profile Trace: Link (there should be only FSDP comms)

[profile trace screenshot]

  2. SimpleFSDP + EP (degree=2)
  • Loss: [loss curve screenshots]

  3. SimpleFSDP + TP (degree=2) + EP (degree=2) + ETP (degree=1)
  • Loss: [loss curve screenshots]

  4. HSDP (dp_shard=2, dp_replicate=2) + EP (degree=2)
  • Loss: [loss curve screenshots]

  5. HSDP (dp_shard=2, dp_replicate=2) + TP (degree=2) + EP (degree=2) + ETP (degree=1)
  • Loss: [loss curve screenshots]

Status

  1. Currently, SimpleFSDP + EP + AC doesn't work because of graph breaks from mutation inside the AC HOP (see the comment thread below).

  2. @ezyang also points out that numerical issues in Float8 training may be related to this, since quantization calibration involves mutation, which might not compose well with AC.

  3. There are some minor numerical differences after EP is added. @tianyu-l suggests it's because the reduce-scatter is averaged over the whole dp_mesh, instead of dp_mod_ep_mesh, for the EP part. We will fix this in a separate PR.

  4. PP (1F1B and Interleaved 1F1B) + EP + SimpleFSDP has a series of composability issues: see #1529 (comment).

cc. @anijain2305

@meta-cla meta-cla bot added the CLA Signed label Aug 5, 2025
@ruisizhang123 ruisizhang123 marked this pull request as draft August 5, 2025 16:35
@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 8 times, most recently from 058cf25 to 407f5f8 Compare August 7, 2025 18:46
@ruisizhang123 ruisizhang123 changed the title [WIP] add support for simplefsdp+ep add support for simplefsdp+ep Aug 7, 2025
@ruisizhang123 ruisizhang123 marked this pull request as ready for review August 7, 2025 19:05
@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 4 times, most recently from 97bb837 to 273e7b4 Compare August 14, 2025 23:07
@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 2 times, most recently from b1d2def to f194b2c Compare September 22, 2025 21:48
@ezyang
Contributor

ezyang commented Sep 23, 2025

one is a buffer mutation in self.input_splits = num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1) in expert_parallel.py, after confirming with @xmfan LINK;

This graph break can be removed by turning on capture_scalar_outputs right?

@ruisizhang123
Contributor Author

ruisizhang123 commented Sep 23, 2025

one is a buffer mutation in self.input_splits = num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1) in expert_parallel.py, after confirming with @xmfan LINK;

This graph break can be removed by turning on capture_scalar_outputs right?

Yes, it's turned on here:

torch._dynamo.config.capture_scalar_outputs = True

The only problem now is with AC + SimpleFSDP + EP, which hits some HOP mutation errors.

@xmfan
Member

xmfan commented Sep 23, 2025

The AC HOP doesn't allow mutations inside it; what are you wrapping in AC for this configuration?

If you apply AC to attention and the MoE experts separately, we can rewrite the buffer mutation to happen outside of the MoE.

@ruisizhang123
Contributor Author

I tried SAC with the default torchtitan wrapping setting. I'm working on a refactor of this PR. I can try wrapping AC separately later and hand it to you to check the mutation errors?

@ruisizhang123
Contributor Author

ruisizhang123 commented Sep 23, 2025

The log for the AC mutation error is below.

Note: tolist() works without AC, and I don't think it's a problem with capture_scalar_outputs. You can reproduce this with CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_simple_fsdp --parallelism.expert_parallel_degree 2 --profiling.enable_profiling --activation_checkpoint.mode "selective" --compile.enable

cc @ezyang @xmfan

torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
    Explanation: This is not supported.

 Developer debug context: 
  
   For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0067.html
  
  from user code:
     File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 388, in forward
      h = layer(h, self.freqs_cis)
    File "/home/ruisizhang123/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 300, in forward
      x = x + self.moe(self.ffn_norm(x))
    File "/home/ruisizhang123/torchtitan/torchtitan/models/moe.py", line 418, in forward
      routed_output = self.experts(routed_input, num_tokens_per_expert)
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1878, in _call_impl
      return inner()
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1818, in inner
      args_result = hook(self, args)
    File "/home/ruisizhang123/pytorch/torch/distributed/tensor/_api.py", line 938, in <lambda>
      lambda mod, inputs: input_fn(mod, inputs, device_mesh)
    File "/home/ruisizhang123/torchtitan/torchtitan/distributed/expert_parallel.py", line 119, in _token_dispatch
      self.input_splits = input_splits.tolist()
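For context, here is a minimal, self-contained sketch of the failure mode in the log above (hypothetical module and names, not torchtitan code): a module attribute is mutated inside an activation-checkpointed region, which the checkpoint HOP under torch.compile treats as an illegal side effect.

# Minimal sketch of the HOP mutation error above (hypothetical module, not torchtitan code).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch._dynamo.config.capture_scalar_outputs = True

class TokenDispatcher(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_splits = None  # attribute mutated during forward

    def forward(self, num_tokens_per_expert, x):
        # mutating a variable outside the checkpoint HOP's scope -> SideEffects error under compile
        self.input_splits = num_tokens_per_expert.tolist()
        return x * 2

mod = TokenDispatcher()

def run(x, counts):
    return checkpoint(mod, counts, x, use_reentrant=False)

compiled = torch.compile(run, fullgraph=True)
# compiled(torch.randn(8), torch.tensor([4, 4]))
# expected: torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable
# not in the current scope (SideEffects), mirroring the log above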

@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 2 times, most recently from e386ea6 to 8696988 Compare September 25, 2025 05:41
Contributor

@tianyu-l tianyu-l left a comment


Thanks for the ground-breaking work on SimpleFSDP + EP.

Left comments.


if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
Contributor

Last time I heard from @xmfan that there are correctness concerns with this flag, although it gives us a full graph?

Contributor Author

It is to ensure the tolist() in A2A can be traced into the graph; otherwise it would cause graph breaks. It reads the items out of the tensor one by one, as .item() does. Not sure if there would be correctness concerns @xmfan.

Contributor

There shouldn't be. Bother @bobrenjc93 and @laithsakka if there are


So flipping those flags should NOT "affect correctness", but let's define correctness here.

Flipping those flags opts into unbacked semantics for some data-dependent ops (but that's the only way you can trace through those). With unbacked semantics, some behaviors could deviate from eager in ways that are not very harmful; the usual side effects are different output strides, or extra clones.

For example, a reshape that depends on unbacked symbols (outputs of .tolist()) could result in a clone, changing the output strides of the reshape; had those symbols been known and the reshape translated to a view, the strides could have been different.
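To make this concrete, here is a minimal sketch of what capture_scalar_outputs enables, assuming a recent PyTorch (the function below is illustrative, not torchtitan's token dispatch): .tolist() is traced into unbacked SymInts instead of graph-breaking, and a torch._check supplies the fact the compiler cannot derive on its own.

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def dispatch(x, num_tokens_per_expert):
    splits = num_tokens_per_expert.tolist()   # traced as unbacked SymInts u0, u1, ...
    torch._check(sum(splits) == x.shape[0])   # runtime assert so the split below is well-formed
    return torch.split(x, splits)             # output shapes depend on unbacked symbols

out = dispatch(torch.randn(8, 16), torch.tensor([3, 5]))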

transformer_block.moe.shared_experts = data_parallel(
transformer_block.moe.shared_experts,
dp_mesh,
"replicate",
Contributor

I'd recommend we keep the same decision as the torchtitan FSDP2 code -- shard instead of replicate. Using replicate here would mean that it can't be bucketed with the rest of the TransformerBlock, as the comms are different.

The perf / tradeoffs are not clear at this moment.


Do we know the perf ROI for replicate? Is there value in micro-benchmarking either approach?

Contributor Author

The replicated compute in the shared expert can potentially be overlapped with the A2A in token dispatch/combine. However, the shared expert is relatively small in DSV3, which is why Tianyu said replicating or sharding it won't make a big difference to perf.

Comment on lines 352 to 374
# we shouldn't parametrize the parameters that are already sharded in DP
if isinstance(p, DTensor) and any(
"dp_shard" in dim_name for dim_name in p._spec.mesh.mesh_dim_names
):
enable_parametrizations.append(False)
else:
enable_parametrizations.append(True)
Contributor

Sorry, I can't parse this. Is it for nested data_parallel wrapping?
Could you elaborate on the idea? I feel this is the biggest UX decision to make, so let's be careful.

Contributor Author

Yes, this is for nested data_parallel. Specifically, it's for the MoE wrapping and model wrapping here: filtering out which modules have already been wrapped by data_parallel and wrapping the rest would make deepseek_v3_parallelize.py messy.
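A minimal sketch of the idea, mirroring the snippet quoted above but using the public DTensor accessors (the dim name dp_shard is taken from the snippet; this is illustrative, not the PR's exact code):

from torch.distributed.tensor import DTensor

def needs_parametrization(p) -> bool:
    # skip params already sharded by an inner data_parallel call (e.g. MoE experts),
    # detected by a dp_shard dim already present on the param's mesh
    if isinstance(p, DTensor) and any(
        "dp_shard" in name for name in (p.device_mesh.mesh_dim_names or ())
    ):
        return False
    return True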

Comment on lines 331 to 336
param_sharding = (Replicate(),) * device_mesh.ndim
elif mode == "fully_shard":
param_sharding = (Shard(0),)
param_sharding = (Shard(0),) * device_mesh.ndim
elif mode == "hybrid_shard":
# replicate inter-host, fully shard intra-host
param_sharding = (Replicate(), Shard(0))
Contributor

We shouldn't always do DP shard on dim-0 for MoE. See #1561

Contributor Author

Makes sense, I hit the dim-mismatch bug when adding HSDP. Will try to update it like this.

Contributor Author

Here are the testing results following your PR. I also added _StridedShard based on offline discussion:

  1. 4 experts, FSDP2 EP4

[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))

  2. 8 experts, FSDP2 EP4

[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:p after DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))

inner_spec.placements[1],
)
else:
# For FSDP + TP
Contributor

Why does it have to be FSDP? Can it be FSDP + EP?

Contributor Author

No, because this _distribute_dtensor function would only be called when p is a DTensor -- p has to be sharded by TP to be a DTensor. In the FSDP + EP case, it would call distribute_tensor instead of this one.

Contributor

Why? After EP sharding it's also a DTensor.

Contributor Author

Hmmm yes, you are right. It will call _distribute_dtensor. But the functionality is still the same, since it would be either FSDP+EP or FSDP+TP. We check whether "tp" is in the device mesh dim names when calling _StridedShard.

grad_placements=self.grad_placements
)

if (
Contributor

Again, I think here we should try to make the code generic, without hardcoding "ep" / "tp" if possible. Please give it a try and let me know if you find it impossible.

Contributor Author

I totally agree. I think we can explicitly pass the dim names used for "ep" and "tp" into the data_parallel function?

Contributor Author

Do you think having a dp_tp_ep_mesh_name arg to allow users to specify their tp/ep mesh names would make sense? Link
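A sketch of the generic alternative, i.e. deriving the non-DP dims from the param's own mesh instead of hardcoding "ep"/"tp" (function and argument names here are illustrative, not the PR's API):

from torch.distributed.tensor import DTensor

def non_dp_mesh_dim_names(param: DTensor, dp_dim_names: set[str]) -> list[str]:
    # everything on the param's mesh that is not a DP dim is a "model parallel" dim (TP/EP/ETP)
    names = param.device_mesh.mesh_dim_names or ()
    return [n for n in names if n not in dp_dim_names]

# usage: non_dp_mesh_dim_names(p, {"dp_replicate", "dp_shard", "dp_shard_mod_ep"})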

Comment on lines 125 to 135
transformer_block.moe.experts = data_parallel(
transformer_block.moe.experts,
world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
)
Contributor

We also need to solve this issue:
#1551
Otherwise the gradient has the wrong scale.

Contributor Author

@ruisizhang123 ruisizhang123 Sep 26, 2025

Nice catch, will add!

@wconstab
Contributor

@tianyu-l do you want to give the final stamp? I think it looks pretty good overall but you raised some good points.

@ruisizhang123
Contributor Author

I was able to address most of @tianyu-l's comments, except for fsdp_gradient_divide_factor in this PR. I realized FSDP2 applies the divide factor inside foreach_reduce here. I will need to do a similar thing for ReplicaCompute. It would be easier to review if I add the gradient_divide_factor in a separate PR.

@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 2 times, most recently from 7513799 to ae8cee7 Compare September 26, 2025 23:44
Contributor

@tianyu-l tianyu-l left a comment

As discussed offline, I would recommend:

  1. supporting nested data_parallel calls, but relying on the altered FSDP module name instead of the device mesh name, which is not reliable.
  2. supporting the gradient divisor via a custom subclass of Partial. This can be done in a separate PR.

@ruisizhang123
Contributor Author

ruisizhang123 commented Sep 27, 2025

Thank you for the summary.

  1. I'm now using __class__.__name__ to check if the module has been wrapped by SimpleFSDP here.

  2. I'm also using the non_dp_mesh_dim to handle EP/TP/EP+TP wrapping in ReplicateCompute here. As such, we won't hardcode the EP/TP dim names.
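For reference, a trivial sketch of the class-name check in (1); the wrapper class name below is an assumption for illustration, not necessarily SimpleFSDP's actual class:

def is_simple_fsdp_wrapped(obj) -> bool:
    # detect prior wrapping by class name rather than by device mesh dim names
    return obj.__class__.__name__ == "ReplicateComputation"  # hypothetical wrapper class name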

Contributor

@tianyu-l tianyu-l left a comment

Almost there!

Contributor

let's

  1. not change the name of this function
  2. include SimpleFSDP + EP composability tests for deepseek v3, e.g. you can consider these two https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/models.py#L47-L79

Contributor Author

@ruisizhang123 ruisizhang123 Sep 30, 2025

Sounds good. Added new tests for FSDP+EP / FSDP+EP+TP / FSDP+EP+ETP.

I skipped PP since I found that both 1F1B and Interleaved 1F1B don't compose well with EP+TP. The following PP configurations all work in eager mode; however, after adding compile, I hit some errors.

Here are some explorations (I will also add them to the PR summary). I didn't do in-depth debugging, but most of them seem to be related to dynamic shapes.

  1. SimpleFSDP + PP (1F1B) + TP + compile: works
  2. SimpleFSDP + PP (1F1B) + EP + compile: failed with error below (Tlparse Link)

[rank0]:[titan] 2025-09-29 21:51:23,835 - root - INFO - Training starts at step 1
[rank0]:/home/ruisizhang123/pytorch/torch/distributed/device_mesh.py:321: UserWarning: You are attempting to slice a submesh from another submesh. While we support this operation, it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. If not, this may result in some ranks receiving the submesh while others encounter errors.
[rank0]: warnings.warn(
[rank0]:/home/ruisizhang123/pytorch/torch/distributed/device_mesh.py:321: UserWarning: You are attempting to slice a submesh from another submesh. While we support this operation, it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. If not, this may result in some ranks receiving the submesh while others encounter errors.
[rank0]: warnings.warn(
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0]
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
[rank0]:[rank0]:W0929 21:51:27.582000 1518045 torch/_logging/_internal.py:1199] [0/0]
[rank0]:/home/ruisizhang123/pytorch/torch/_inductor/lowering.py:2004: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "", line 198, in _run_module_as_main
[rank0]:[rank0]: File "", line 88, in _run_code
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 671, in
[rank0]:[rank0]: trainer.train()
[rank0]:[rank0]: ~~~~~~~~~~~~~^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/init.py", line 357, in wrapper
[rank0]:[rank0]: return f(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 596, in train
[rank0]:[rank0]: self.train_step(data_iterator)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 496, in train_step
[rank0]:[rank0]: loss = self.forward_backward_step(input_dict, labels)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 445, in forward_backward_step
[rank0]:[rank0]: self.pp_schedule.step(
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~^
[rank0]:[rank0]: inputs,
[rank0]:[rank0]: ^^^^^^^
[rank0]:[rank0]: ...<3 lines>...
[rank0]:[rank0]: input_batch=inputs,
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: )
[rank0]:[rank0]: ^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 624, in step
[rank0]:[rank0]: self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 839, in _step_microbatches
[rank0]:[rank0]: self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/schedules.py", line 583, in _initialize_stage
[rank0]:[rank0]: self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/stage.py", line 1510, in _prepare_forward_infra
[rank0]:[rank0]: outputs = self._shape_inference(args, kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/distributed/pipelining/stage.py", line 1441, in _shape_inference
[rank0]:[rank0]: outputs = self.submod(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]:[rank0]: return super().__call__(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 363, in forward
[rank0]:[rank0]: def forward(
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1134, in forward
[rank0]:[rank0]: return compiled_fn(full_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 355, in runtime_wrapper
[rank0]:[rank0]: all_outs = call_func_at_runtime_with_args(
[rank0]:[rank0]: compiled_fn, args, disable_amp=disable_amp, steal_args=True
[rank0]:[rank0]: )
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/utils.py", line 130, in call_func_at_runtime_with_args
[rank0]:[rank0]: out = normalize_as_list(f(args))
[rank0]:[rank0]: ~^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 693, in inner_fn
[rank0]:[rank0]: unwrapped_outs = compiled_fn(unwrapped_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 727, in inner_fn
[rank0]:[rank0]: outs = compiled_fn(args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 529, in wrapper
[rank0]:[rank0]: return compiled_fn(runtime_args)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_inductor/output_code.py", line 617, in __call__
[rank0]:[rank0]: return self.current_callable(inputs)
[rank0]:[rank0]: ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/_inductor/utils.py", line 3018, in run
[rank0]:[rank0]: out = model(new_inputs)
[rank0]:[rank0]: File "/tmp/torchinductor_ruisizhang123/sk/csk3kend5qh55mkvyqnaqg7tj7zhubsaoiewgglmofcgbpn6me4y.py", line 2926, in call
[rank0]:[rank0]: raise RuntimeError('Eq(u0 + u1, 6144)')
[rank0]:[rank0]: RuntimeError: Eq(u0 + u1, 6144)

  3. SimpleFSDP + PP (Interleaved1F1B) + TP + compile: failed with error below (full log Link, Tlparse Link)

File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 571, in thunkify
[rank0]:[rank0]: r = f(*args, **kwargs)
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 1584, in _compute_proxy
[rank0]:[rank0]: n_args = tuple(
[rank0]:[rank0]: (
[rank0]:[rank0]: ...<4 lines>...
[rank0]:[rank0]: for a in args
[rank0]:[rank0]: )
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 1586, in
[rank0]:[rank0]: get_proxy_slot(a, self.tracer).force().node
[rank0]:[rank0]: ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/ruisizhang123/pytorch/torch/fx/experimental/proxy_tensor.py", line 411, in get_proxy_slot
[rank0]:[rank0]: raise RuntimeError(
[rank0]:[rank0]: f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
[rank0]:[rank0]: )
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: RuntimeError: 256*s27 (140036273922896)is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7f5c29750fc0>
[rank0]:
[rank0]:[rank0]: While executing %embedding : [num_users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_tensor, %output, None, None, 2.0, False, False), kwargs = {})
[rank0]:[rank0]: Original traceback:
[rank0]:[rank0]: File "/home/ruisizhang123/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 385, in forward
[rank0]:[rank0]: h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
[rank0]:
[rank0]:[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
[rank0]:

  4. SimpleFSDP + PP (Interleaved1F1B) + EP + compile: went through, but not sure if it's really working (Tlparse Link)

I checked the Tlparse for 4 and found it has the same Eq(u0 + u1, 6144) expression as 2 (SimpleFSDP + PP (1F1B) + EP + compile). It's just that the code didn't hit the Eq(u0 + u1, 6144) branch of the if-else.

cc. @xmfan @ezyang @bdhirsh @sanketpurandare

@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 2 times, most recently from 1df12f1 to ce849e6 Compare September 30, 2025 04:23
@ruisizhang123
Contributor Author

Almost there!

Thank you for the review. Updated accordingly.

@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 3 times, most recently from ce67540 to 0761f4d Compare September 30, 2025 18:29
Contributor

@tianyu-l tianyu-l left a comment

Looks very nice. Let's follow up with

  1. Partial enhancement
  2. fixing PP

Please fix a minor comment issue (see inline) before merge.

)

# re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh
# re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh/ep_mesh
Contributor

@tianyu-l tianyu-l Sep 30, 2025

This comment doesn't sound right, even before this PR.
I think what it tries to say is

Suggested change
# re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh/ep_mesh
# re-wrap all-gathered DTensor on dp_mesh to be on non_dp_mesh

@ruisizhang123 ruisizhang123 force-pushed the ruisi/simplefsdp_ep branch 2 times, most recently from 48dd348 to 1360a85 Compare September 30, 2025 20:39
@ruisizhang123 ruisizhang123 merged commit 37e536d into main Sep 30, 2025
5 checks passed
@ruisizhang123 ruisizhang123 deleted the ruisi/simplefsdp_ep branch September 30, 2025 22:37
ruisizhang123 added a commit that referenced this pull request Oct 9, 2025
This PR is a follow-up of the SimpleFSDP+EP
[PR](#1529). Here, we add a
`gradient_divide_factor` following FSDP2 to ensure modules wrapped by
FSDP+EP have the correct gradient reduction value.

- The original FSDP2 implementation is in this
[PR](#1551).
- The `gradient_divide_factor` logic is
[here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688)

We have two ways of handling `gradient_divide_factor` in
`reduce_scatter`:

1. The first one is to use `ReduceOp.PREMUL_SUM` to handle the
`gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only
accepts `reduce_op` as a str input
([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)).

To make `_reduce_shard_value` work correctly with ReduceOp.PREMUL_SUM,
we need to update the DTensor `_reduce_shard_tensor` and
`torch.distributed._functional_collectives.reduce_scatter_tensor` so
that it can pass the factor associated with ReduceOp.PREMUL_SUM as an
input.



2. Another way is to simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`.
The logic is in this [Diff](https://www.internalfb.com/diff/D76546536).
It does a `div_` over gradient before performing `ReduceOp.SUM`.

Currently I'm following 2, since it requires fewer changes to
`_functional_collectives`.


After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP (=4)
have identical loss:

<img width="1194" height="780" alt="Screenshot 2025-10-08 at 5 27 24 PM"
src="https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de"
/>
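
A minimal sketch of approach 2 above (emulating `ReduceOp.PREMUL_SUM` by pre-dividing before a SUM reduce-scatter); the helper below is illustrative, not the PR's implementation, and assumes an initialized process group and a grad whose dim 0 is divisible by the group size:

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad(grad: torch.Tensor, group, gradient_divide_factor: float) -> torch.Tensor:
    world = dist.get_world_size(group)
    grad = grad / gradient_divide_factor          # pre-divide instead of PREMUL_SUM
    out = torch.empty_like(grad.chunk(world)[0])  # per-rank output shard
    dist.reduce_scatter_tensor(out, grad.contiguous(), op=dist.ReduceOp.SUM, group=group)
    return out
```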
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025