
Conversation

@xmfan (Member) commented Jul 23, 2025

This PR adds a HOP that works with the aot_module_export frontend of autoparallel. The lack of a Dynamo frontend imposes some limitations, documented on the apply_local_map API, which will also no longer be needed once we move to Dynamo.

At a high-level:

  • We manually wrap local_map'd functions into local_map_hop.
  • When we trace, local_map_hop implements dispatch keys to prevent the compiler from inlining into it. Instead, it inserts a call_function node into the tracer's graph, with node metadata containing the local_map placement information.
  • The solver special-cases nodes with local_map placement information and allows the node to be inlined in the next trace by adding a kwarg: _inline. This special handling is actually desired: we need the solver to consider the placements mandated by local_map in its global cost optimization.
  • When we trace while applying sharding, the local_map_hop is inlined, with tensors that should already respect the local_map placements, so local_map becomes a no-op. I considered having the HOP exclude the local_map and wrap only the inner function, but opted against it because it would require special handling of the HOP and would only work for Autoparallel. (A usage sketch follows this list.)
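For illustration, here is a minimal sketch of how a region might be declared with apply_local_map, following the decorator usage from examples/example_local_map.py quoted in the review comments below. The in_placements keyword, the placements, and the wrapped function body are illustrative assumptions, and the apply_local_map import is omitted because its path isn't shown on this page.

# Illustrative sketch only: apart from out_placements (visible in the example
# file quoted below), the keyword names, placements, and function body are assumptions.
from torch.distributed.tensor.placement_types import Replicate, Shard

@apply_local_map(
    out_placements=[[Shard(0), Shard(0), Replicate()]],  # S(0)S(0)R, matching the sharded_pointwise row in the solver log below
    in_placements=[[Shard(0), Shard(0), Replicate()]],
)
def sharded_pointwise(x):
    # The body runs on local shards; the HOP records the placement contract
    # so the solver can honor it in its global cost optimization.
    return x * 2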

Test running: PYTHONPATH=. python examples/example_local_map.py
Sharding and cost estimations look reasonable: P1881707012.

# Graph solver sees
alias_default = torch.ops.aten.alias.default(param_0);  param_0 = None  # placement=(S(0)S(0)S(0)) -> S(0)S(0)S(0), cost=[0.0]
# sharded_pointwise
call_local_map = autoparallel_local_map_hop_call_local_map(alias_default, _inline = True);  alias_default = None  # placement=(S(0)S(0)R) -> S(0)S(0)R, cost=[27.190755059863168]
alias_default_1 = torch.ops.aten.alias.default(param_1);  param_1 = None  # placement=(RRR) -> RRR, cost=[59.88767436634411]
alias_default_12 = torch.ops.aten.alias.default(input_0);  input_0 = None  # placement=(RRS(0)) -> RRS(0), cost=[5088.284471856535]
# replicate_linear
call_local_map_1 = autoparallel_local_map_hop_call_local_map_1(call_local_map, alias_default_1, alias_default_12, _inline = True);  call_local_map = alias_default_1 = None  # placement=(RRR, RRR, RRR) -> RRR, cost=[7225.651931947756, 0.0, 3215.355530216648]
view = torch.ops.aten.view.default(alias_default_12, [16384, 6144])  # placement=(RRS(0)) -> RRS(0), cost=[0.0]
alias_default_2 = torch.ops.aten.alias.default(param_2);  param_2 = None  # placement=(S(0)S(0)R) -> S(0)S(0)R, cost=[27.190755059863168]
...
permute_2 = torch.ops.aten.permute.default(view_4, [0, 2, 1, 3]);  view_4 = None  # placement=(RRR) -> RRR, cost=[0.0]
view_5 = torch.ops.aten.view.default(view_1, [64, 256, 48, 128]);  view_1 = None  # placement=(S(0)S(2)S(0)) -> S(0)S(2)S(0), cost=[0.0]
permute_3 = torch.ops.aten.permute.default(view_5, [0, 2, 1, 3]);  view_5 = None  # placement=(S(0)S(2)R) -> S(0)S(1)R, cost=[58.50868015963512]
view_6 = torch.ops.aten.view.default(view_3, [64, 256, 48, 128]);  view_3 = None  # placement=(S(0)S(2)R) -> S(0)S(2)R, cost=[151.4406424924847]
permute_4 = torch.ops.aten.permute.default(view_6, [0, 2, 1, 3]);  view_6 = None  # placement=(S(0)S(2)R) -> S(0)S(1)R, cost=[0.0]
# context parallel
call_local_map_2 = autoparallel_local_map_hop_call_local_map_2(permute_2, permute_3, permute_4, _inline = True)  # placement=(S(0)S(1)S(2), S(0)S(1)R, S(0)S(1)R) -> S(0)S(1)S(2), cost=[0.0, 0.0, 0.0]
permute_5 = torch.ops.aten.permute.default(call_local_map_2, [0, 2, 1, 3]);  call_local_map_2 = None  # placement=(S(0)S(1)S(2)) -> S(0)S(2)S(1), cost=[0.0]
view_7 = torch.ops.aten.view.default(permute_5, [64, 256, 6144]);  permute_5 = None  # placement=(S(0)S(2)S(1)) -> S(0)S(2)S(1), cost=[0.0]
view_8 = torch.ops.aten.view.default(view_7, [16384, 6144]);  view_7 = None  # placement=(S(0)S(2)R) -> S(0)S(1)R, cost=[234.03472063854048]

@ezyang (Contributor) commented Jul 23, 2025

HOP should live in pytorch/pytorch proper. OK to land here for now but better to get to the main repo sooner rather than later (in case someone is doing HOP refactoring).

@ezyang (Contributor) commented Jul 23, 2025

> Solver special handles nodes with local_map placement information, and allows the node to be inlined in the next trace by adding a kwarg: _inline.

autoparallel retraces with make_fx, so it seems like it can just evaporate it when that happens.

out = parallel_mod(*x)
out.backward(torch.randn_like(out))

print("All good!")
Contributor:

This example doesn't actually tell us that the local_map regions had their sharding respected; you have to explicitly check it in the output! Additionally, the functions being local_map'd are a bit "too" simple and are expressible as DTensor operations, IIUC.

@xmfan (Member Author), Jul 23, 2025:

Yeah, I didn't write the checks into the code, but from the tlparse, the solver's estimated costs look like the sharding decision made it through. I'm going to try it with context parallel to see if it already works.



def get_local_map_placement_option(
    mesh, op, specs, user_args, user_kwargs, output_val, in_placements, out_placements
Contributor:

type signature here would be great
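For reference, one possible annotated signature, written as a sketch: the DeviceMesh/OpOverload/Placement types are inferred from the argument names, while the spec and output value types are assumptions.

# Illustrative only: parameter types are inferred from the argument names and
# may not match the solver's internal types exactly.
from typing import Any, Sequence
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement

def get_local_map_placement_option(
    mesh: DeviceMesh,
    op: torch._ops.OpOverload,
    specs: Sequence[Any],                      # per-input sharding specs (solver-internal type)
    user_args: Sequence[Any],
    user_kwargs: dict[str, Any],
    output_val: Any,                           # fake output used for shape/stride propagation
    in_placements: Sequence[Sequence[Placement]],
    out_placements: Sequence[Placement],
) -> Any:                                      # the returned strategy/placement option
    ...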

)
strats[node] = strat
if local_map_kwargs := node.meta.get("custom", {}).get(
    "dtensor_local_map_kwargs"
Contributor:

Why are we testing the custom meta as opposed to looking at op.target to identify a local map HOP?

Member Author:

We still need to pass the placement kwargs somehow; op.target alone wouldn't be enough.
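To make the point concrete, a small sketch of the pattern under discussion: the target identifies the node as a local_map HOP call, but the requested placements still have to travel through the node's custom metadata. Only the "custom"/"dtensor_local_map_kwargs" keys come from this PR; the other key names are assumptions.

# Illustrative sketch; key names other than "custom"/"dtensor_local_map_kwargs" are assumed.
def read_local_map_placements(node):
    # Knowing the target is a local_map HOP tells us *that* placements were
    # requested, but not *which* ones; those live in the node metadata.
    local_map_kwargs = node.meta.get("custom", {}).get("dtensor_local_map_kwargs")
    if local_map_kwargs is None:
        return None
    return local_map_kwargs.get("in_placements"), local_map_kwargs.get("out_placements")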

assert local_map_kwargs[
    "redistribute_inputs"
], "Autoparallel should always be allowed to redistribute inputs"
assert local_map_kwargs["in_grad_placements"] is None, "Not yet implemented"
Comment:

My first impression was that in_grad_placements could be supported out of the box, but it seems that's not the case?

Member Author:

I don't think the user should be allowed to specify in_grad_placements, since in autoparallel the user would have no way of knowing the solver's decision at the time of writing the local_map. But there could be an API to actually tell the solver to shard the grads in a certain way. TBD, I don't have a clear proposal yet.

"""

def __init__(self):
    super().__init__("local_map_hop")
Contributor:

nit: maybe this is hard, but it would be nice if some info about orig_fwd (e.g. name/line) were included in the name of the hop, so that the graph doesn't just show "hop1, hop2, hop3" with no info about what's inside

Member Author:

Having stack traces in the autoparallel graphs for all nodes would address this issue too.

with suspend_functionalization(), disable_functional_mode():
    with disable_proxy_modes_tracing():

        def _from_fun(t):
Contributor:

more descriptive name? (maybe _empty_like() or something)

Contributor:

Oh, I totally missed that "fun" means functional. Although I still don't know what that means.
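For context, "fun" is short for "functional" here: the helper takes a (possibly functionalized) tensor and produces a plain tensor with the same metadata, so the subgraph can be traced outside the functionalization wrapper. A rough sketch of the idea, not the exact implementation:

# Rough sketch of what a _from_fun-style helper does; the real helper in core
# also handles functional-tensor unwrapping and special dtypes.
import torch

def _from_fun(t):
    if not isinstance(t, torch.Tensor):
        return t
    # Fresh tensor with matching size/stride/dtype/device/requires_grad,
    # detached from any functionalization state.
    return torch.empty_strided(
        t.size(),
        t.stride(),
        dtype=t.dtype,
        device=t.device,
        requires_grad=t.requires_grad,
    )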

isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs
)

# redundant? we already _from_fun'd the inputs
Contributor:

Is the concern that user code could allocate new tensors? If they did, they would already be fake because of the mode, so I'm not sure why we can't return them directly.

Member Author:

I'm not sure why other HOPs do this. You shouldn't need to unwrap the outputs if the inputs were already unwrapped and fw_func doesn't introduce new functional tensors.

Like you said, newly allocated tensors should just be fakes, not functional-wrapped. The only thing I can think of is if fw_inputs weren't properly unwrapped.

for example_grad in example_grads:
    if example_grad.requires_grad:
        optional_grads.append(example_grad)
_, grads = joint(fw_inputs, optional_grads)
Contributor:

I was expecting joint to return both forward outputs and grads. Maybe I just misunderstood this?

@xmfan (Member Author), Jul 23, 2025:

We could; it doesn't matter if fw_func always gives the same graph.
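For context, the joint convention being referenced: the joint callable takes (primals, tangents) and returns (forward outputs, grads), and only the grads are kept at this call site. A minimal sketch with a stand-in forward (fw_func here is illustrative, not the PR's actual wrapped function):

# Minimal sketch of the (primals, tangents) -> (outputs, grads) convention;
# fw_func is a stand-in for the wrapped forward.
import torch

def fw_func(x, w):
    return (x @ w,)

def joint(primals, tangents):
    outputs = fw_func(*primals)
    grads = torch.autograd.grad(
        outputs, primals, grad_outputs=tangents, allow_unused=True
    )
    return outputs, grads

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
# Only the grads are needed when building the backward graph:
_, grads = joint((x, w), (torch.randn(4, 2),))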

save_tensors_and_symints_for_backward(ctx, args)
ctx.joint_graph = joint_graph

with torch._C._AutoDispatchBelowAutograd(): # why
Contributor:

what is this for?

Member Author:

IIUC, we need to do this because issued ops go through all of their dispatch keys by default. Dispatching below autograd hits the functional key next, which enters the redispatch_to_next context, which applies to all keys afterwards.

I'm not sure why we need to explicitly use this API here instead of already being inside the redispatch_to_next context. I'm just following some existing HOP code...


unwrapped_inputs = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
    # TODO: dynamo safety checks on orig_fwd
Contributor:

Yes please. It's a little hard to believe that there isn't an easy utility for this?

@xmfan (Member Author), Jul 23, 2025:

There is: speculate_subgraph, but it needs Dynamo to apply it and to pass a safe orig_fwd_graph instead of orig_fwd to the HOP.

Or were you thinking of torch._subclasses.functional_tensor.*.functionalize?

Contributor:

I don't have a concrete implementation in mind; I'm just working off the prior that we already have a lot of HOPs in PyTorch core, and this one is, by all accounts, a very tame one, so it should be subsumed by some of the implementations already in core. Hmm... paging @zou3519 @ydwu4

@ydwu4, Jul 24, 2025:

We have a util, check_input_alias_and_mutation_return_outputs, for checking aliasing and mutations of tensor inputs. For the others, like side effects and lifting tensor closures/symints as inputs (to make sure dependencies are captured correctly), we would need to turn on Dynamo and use speculate_subgraph.

"call_function", call_local_map, proxy_args, {}
)
out_proxy.node.meta["custom"] = {
"dtensor_local_map_kwargs": orig_fwd.local_map_kwargs,
Contributor:

Storing this in the meta gives me the jeebies, because if you lose the meta there's no way to get this information back, it is load bearing lol

Member Author:

It is sus lol, but meta is already load bearing for export/AC though, and we have special-cased keys: https://github.com/pytorch/pytorch/blame/255c0545e7eac2ec6d00a41a3fc9d6d8201f8f39/torch/fx/proxy.py#L107-L120. "custom" was added specifically to try to decouple. I think it's probably one of the better places on the fx.Graph, and I haven't really looked at passing this info along outside of the graph object.

Contributor:

@oulgen @jamesjwu I hope we're caching on this meta key!!!

@oulgen, Jul 24, 2025:

> @oulgen @jamesjwu I hope we're caching on this meta key!!!

I'm pretty sure we are not. I think @zou3519 said that most metas get dropped before getting to Inductor?

I do agree that we should not be using meta as a dumping ground, though.

Contributor:

Concretely, this sounds like AC is not cacheable by AOTAutograd then. But I guess AC isn't cacheable for other reasons too lol.

Comment:

Yeah, I remember @xmfan and @jamesjwu talked about this. @xmfan, are there any explicit plans to support caching?

Member Author:

In defense of node.meta, this pattern exists because it's quite difficult to propagate information between different compiler modules when tracing. Same story with writing to some global like TracingContext/Virtualized. As long as the types can be cached, can't we snapshot their values at cache lookup time?

local_map_hop = LocalMapAOTExportModule()


def create_hop_joint_graph(
Contributor:

Is this different than create_joint?

@ezyang (Contributor) left a comment:

I don't have any problem with this going in as an unblock. I think it probably needs more work before going in pytorch core.

from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree


class LocalMapAOTExportModule(HigherOrderOperator):
Comment:

I think we said no HOPs out of core (@ezyang). Especially if you want to sync this to fbcode.

Member Author:

Okay, I will upstream ASAP

Comment on lines 54 to 56
[Replicate(), Replicate(), Shard(2)],
[Replicate(), Replicate(), Shard(2)],
[Replicate(), Replicate(), Shard(2)],
Contributor:

Can we instead make the placement for the query be S(0), S(1), S(2), and for the keys and values S(0), S(1), R? This way we are closer to actually enforcing CP, since we will force an all-gather of the keys and values, and things should work properly (as long as is_causal=False).

Member Author:

added


@apply_local_map(
    out_placements=[
        [Replicate(), Replicate(), Shard(2)],
Contributor:

If we make the change I mentioned before, the output placement would be S(0), S(1), S(2).
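Putting both suggestions together, a sketch of what the decorator might look like with those placements. Only the list-of-placements-per-mesh-dim format of out_placements is taken from the example above; the in_placements keyword, the apply_local_map import (omitted here), and the attention body are assumptions.

# Sketch of the suggested context-parallel placements; apart from the
# out_placements list format shown above, keyword names and the body are assumed.
import torch
from torch.distributed.tensor.placement_types import Replicate, Shard

@apply_local_map(
    out_placements=[
        [Shard(0), Shard(1), Shard(2)],      # output: S(0), S(1), S(2)
    ],
    in_placements=[
        [Shard(0), Shard(1), Shard(2)],      # query:            S(0), S(1), S(2)
        [Shard(0), Shard(1), Replicate()],   # key (gathered):   S(0), S(1), R
        [Shard(0), Shard(1), Replicate()],   # value (gathered): S(0), S(1), R
    ],
)
def context_parallel_attention(q, k, v):
    # is_causal must stay False for this sharding to be correct.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)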

@xmfan merged commit e6565f1 into main Jul 26, 2025
5 checks passed