256 changes: 256 additions & 0 deletions autoparallel/local_map_hop.py
@@ -0,0 +1,256 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: this file may be removed once we move to a dynamo frontend

import functools

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import (
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tensor.experimental import local_map
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree


class LocalMapAOTExportModule(HigherOrderOperator):
Comment: I think we said no HOPs out of core (@ezyang). Especially if you want to sync this to fbcode.

Member Author: Okay, I will upstream ASAP

"""
A HOP that integrates with autoparallel's current frontend (aot_export_module).
This HOP is introduced in the pre-solver graph and lives until we apply sharding,
at which point orig_fwd is inlined into the post-solver graph.
"""

def __init__(self):
super().__init__("local_map_hop")
Contributor: nit: maybe this is hard, but it would be nice if some info about orig_fwd (e.g. name/line) were included in the name of the hop, so that the graph doesn't just show "hop1, hop2, hop3" with no info about what's inside

Member Author: having stack traces in the autoparallel graphs for all nodes would address this issue too


def __call__(self, orig_fwd, *args, **kwargs):
return super().__call__(orig_fwd, *args, **kwargs)


local_map_hop = LocalMapAOTExportModule()


def create_hop_joint_graph(
Contributor: Is this different than create_joint?

fw_func,
*_args,
):
# Keep these imports local to avoid circular dependencies
# once we upstream with a dynamo frontend
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx

dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)

with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():

# create a tensor (fake) from a compiler wrapped FunctionalTensor
def _from_fun(t):
Contributor: more descriptive name? (maybe _empty_like() or something)

Contributor: oh, i totally missed that fun means functional. Although I still don't know what that means.

if isinstance(t, torch.Tensor):
return torch.empty_strided(
t.size(),
t.stride(),
device=t.device,
dtype=t.dtype,
requires_grad=t.requires_grad,
)
return t

# If someone runs this hop under the default compiler backend ("eager"),
# then this path will be run with the actual user inputs. We convert them
# to fake tensors in order to not perform any actual compute.

fake_mode = detect_fake_mode(_args)
if fake_mode is None:
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

with fake_mode:
fw_inputs = pytree.tree_map(_from_fun, _args)

assert all(
isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs
)

# redundant? we already _from_fun'd the inputs
Contributor: is the concern that usercode could allocate new tensors? if they did, they would already be fake because of the mode, so i'm not sure why we can't return them directly.

Member Author: i'm not sure why other HOPs do this. you shouldn't need to unwrap the outputs if the inputs were already unwrapped and the fw_func doesn't introduce new functional tensors.

like you said allocating new tensors should just be fakes, and not be functional wrapped. the only thing i can think of is if fw_inputs weren't properly unwrapped

example_flat_out = pytree.tree_map(
_from_fun,
fw_func(*fw_inputs),
)
example_grads = _from_fun(example_flat_out)
if not isinstance(example_grads, (list, tuple)):
example_grads = [example_grads]

def joint_f(
*primals_and_tangents,
):
fw_inputs = primals_and_tangents[: len(_args)]
example_grads = primals_and_tangents[len(_args) :]

def run_fwd(*fw_inputs):
outs = fw_func(*fw_inputs)
if not isinstance(outs, (list, tuple)):
outs = (outs,)
masks = [o.requires_grad for o in outs]
return (outs, masks)

joint = create_joint(run_fwd, aot_config=dummy_aot_config)
optional_grads = []
for example_grad in example_grads:
if example_grad.requires_grad:
optional_grads.append(example_grad)
_, grads = joint(fw_inputs, optional_grads)
Contributor: i was expecting joint to return both forward outputs and grads. maybe i just misunderstood this?

Member Author (@xmfan, Jul 23, 2025): we could, it doesn't matter if fw_func always gives the same graph

return grads

primals_and_tangents = [*fw_inputs, *example_grads]
joint_graph = make_fx(joint_f)(*primals_and_tangents)
return None, joint_graph
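
For intuition, a minimal sketch (not part of this diff; the toy function and shapes are hypothetical) of how the traced joint graph could be exercised directly. Its calling convention is (*primals, *tangents) -> input grads, mirroring ctx.joint_graph(*args, *grads) in LocalMapAutogradOp.backward below.

# Illustrative sketch only, assuming the imports at the top of this file.
import torch

def toy_fn(x, y):
    return (x * y).sin()

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)

_, joint_graph = create_hop_joint_graph(toy_fn, x, y)

# Replay the joint graph with real tensors: primals first, then one tangent
# per output that requires grad; it returns gradients w.r.t. the inputs.
grads = joint_graph(x, y, torch.ones(4))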


class LocalMapAutogradOp(torch.autograd.Function):
@staticmethod
def forward(ctx, orig_fwd, joint_graph, *args, **kwargs):
ctx.save_for_backward(*args)

save_tensors_and_symints_for_backward(ctx, args)
ctx.joint_graph = joint_graph

with torch._C._AutoDispatchBelowAutograd():
return local_map_hop(orig_fwd, *args, **kwargs)

@staticmethod
def backward(ctx, *grads):
args = saved_tensors_and_symints(ctx)
grad_ins = ctx.joint_graph(*args, *grads)
# TODO: hopify to support local_map'd function containing custom autograd.Function
return None, None, *grad_ins


@local_map_hop.py_impl(torch._C.DispatchKey.Autograd)
def autograd_key(
orig_fwd,
*args,
**kwargs,
):
if "_inline" in kwargs:
# Solver pass adds a _inline kwarg, which tells this hop to desugar on the next trace
del kwargs["_inline"]
return orig_fwd(*args, **kwargs)

_, joint_graph = create_hop_joint_graph(orig_fwd, *args)
return LocalMapAutogradOp.apply(orig_fwd, joint_graph, *args, **kwargs)


@local_map_hop.py_functionalize_impl
def functional_mode_key(ctx, orig_fwd, *args, **kwargs):
assert not kwargs

unwrapped_inputs = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
# TODO: dynamo safety checks on orig_fwd
Contributor: Yes please. It's a little hard to believe that there isn't an easy utility for this?

Member Author (@xmfan, Jul 23, 2025): There is, speculate_subgraph, but it needs dynamo to apply it, and pass a safe orig_fwd_graph instead of orig_fwd to the hop.

Or were you thinking of torch._subclasses.functional_tensor.*.functionalize

Contributor: I don't have a concrete implementation in mind, I'm just working off of the prior that we already have a lot of HOPs in PyTorch core and this one, by all accounts, is a very tame one, and so it should be subsumed by some of the implementations already in core. Hmm... paging @zou3519 @ydwu4

@ydwu4 (Jul 24, 2025): We have a util check_input_alias_and_mutation_return_outputs for checking aliasing and mutations of tensor inputs. For others like side-effects, lifting tensor closures/symints as inputs (to make sure dependency correctly captured), we would need to turn on dynamo and use speculate_subgraph.

out = local_map_hop(orig_fwd, *unwrapped_inputs)
return ctx.wrap_tensors(out)


@local_map_hop.py_impl(FakeTensorMode)
def fake_mode_key(
mode,
orig_fwd,
*args,
**kwargs,
):
with mode:
return orig_fwd(*args, **kwargs)


@local_map_hop.py_impl(ProxyTorchDispatchMode)
def proxy_mode_key(
proxy_mode,
orig_fwd,
*args,
**kwargs,
):
assert (
proxy_mode is not None
), "Mode should always be enabled for python fallback key"
assert len(kwargs) == 0

example_out = local_map_hop(orig_fwd, *args, **kwargs)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

def call_local_map(*another_args, **another_kwargs):
return functools.partial(local_map_hop, orig_fwd)(
*another_args, **another_kwargs
)

out_proxy = proxy_mode.tracer.create_proxy(
"call_function", call_local_map, proxy_args, {}
)
out_proxy.node.meta["custom"] = {
"dtensor_local_map_kwargs": orig_fwd.local_map_kwargs,
Contributor: Storing this in the meta gives me the jeebies, because if you lose the meta there's no way to get this information back, it is load bearing lol

Member Author: It is sus lol, but meta is already load bearing for export/ac though, and we have special cased keys: https://github.com/pytorch/pytorch/blame/255c0545e7eac2ec6d00a41a3fc9d6d8201f8f39/torch/fx/proxy.py#L107-L120. "custom" was added specifically to try to decouple. I think it's probably one of the better places on the fx.Graph, and haven't really looked at passing this info along outside of the graph obj.

Contributor: @oulgen @jamesjwu I hope we're caching on this meta key!!!

@oulgen (Jul 24, 2025): I'm pretty sure we are not. I think @zou3519 said that most metas get dropped before getting to inductor? I do agree that we should not be using meta as a dumping ground though.

Contributor: Concretely, this sounds like AC is not cacheable by AOTAutograd then. But I guess AC isn't cacheable for other reasons too lol.

Comment: yea I remember @xmfan and @jamesjwu talked about this. @xmfan are there any explicit plans to support caching?

Member Author: In defense of node.meta, this pattern exists because it's quite difficult to propagate information between different compiler modules when tracing. Same story with writing to some global like TracingContext/Virtualized. As long as the types can be cached, can't we snapshot their values at cache lookup time?

}
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
)


# Running HOP in eager with real tensors
@local_map_hop.py_impl(torch._C.DispatchKey.CPU)
@local_map_hop.py_impl(torch._C.DispatchKey.CUDA)
def real_impl(
orig_fwd,
*args,
**kwargs,
):
return orig_fwd(*args, **kwargs)
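
As a rough eager sanity check (not part of this diff; the toy function is hypothetical), running the hop on real tensors dispatches the forward to real_impl, while the backward replays the joint graph captured by autograd_key:

import torch

def toy_fwd(x):
    return x.relu()

x = torch.randn(3, requires_grad=True)
out = local_map_hop(toy_fwd, x)  # autograd_key traces the joint graph; real_impl runs toy_fwd
out.sum().backward()             # LocalMapAutogradOp.backward replays the joint graph for x.grad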


def apply_local_map(*local_map_args, **local_map_kwargs):
# NOTE: We manually issue the hop, which will not be necessary with a dynamo frontend.
# 1. Same as local_map, this must be applied to a function, not a method.
# 2. The local_map'd function must be make_fx traceable. Otherwise, we may
# inline the wrong graph. In a dynamo frontend, speculate_subgraph will handle this.
# 3. All inputs to the local_map'd function must be Tensor types. Otherwise, we won't
# know which tensors to apply _from_fun to. For instance, don't pass nn.Modules to local_map.
# In a dynamo frontend, tensors will be lifted, which will modify the wrapped function's signature.

assert local_map_kwargs[
"redistribute_inputs"
], "Autoparallel should always be allowed to redistribute inputs"
assert local_map_kwargs["in_grad_placements"] is None, "Not yet implemented"
Comment: my first impression was in_grad_placements can be supported out-of-box, but it seems not the case?

Member Author: I don't think the user should be allowed to specify in_grad_placements, since in autoparallel, the user would have no way of knowing the solver's decision at the time of writing local_map. But there could be an API to actually tell the solver to shard the grads in a certain way. TBD, don't have a clear proposal yet

assert local_map_kwargs["device_mesh"] is None, "Must be provided by Autoparallel"

def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
def orig_fwd(*runtime_args, **runtime_kwargs):
# wrap the functools.partial so the hop utils work out of the box
return local_map(
fn,
*local_map_args,
**local_map_kwargs,
)(*runtime_args, **runtime_kwargs)

orig_fwd.local_map_kwargs = local_map_kwargs
return local_map_hop(orig_fwd, *args, **kwargs)

return wrapped

return decorator
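
A minimal usage sketch of apply_local_map (hypothetical function and placements; assumes a 1-D device mesh and a torch version whose local_map accepts in_grad_placements). The decorator records the local_map kwargs on orig_fwd and issues the hop, leaving the device mesh for Autoparallel to fill in:

import torch
from torch.distributed.tensor import Replicate, Shard

@apply_local_map(
    out_placements=(Shard(0),),                   # placements for the single output
    in_placements=((Shard(0),), (Replicate(),)),  # one placement tuple per input tensor
    redistribute_inputs=True,                     # required: the solver may redistribute inputs
    in_grad_placements=None,                      # must be None, not yet implemented
    device_mesh=None,                             # must be None, provided by Autoparallel
)
def sharded_matmul(x, w):
    return torch.matmul(x, w)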
31 changes: 26 additions & 5 deletions autoparallel/optimize_sharding.py
@@ -90,7 +90,7 @@
estimate_strategy_runtime_cost,
)
from .propagation_rules import _create_all_options
from .utils import get_placement_options
from .utils import get_local_map_placement_option, get_placement_options


def _debug_node(node):
@@ -146,10 +146,31 @@ def build_sharding_metadata(self):
user_kwargs = tree_map_only(
torch.fx.Node, lambda x: x.meta["val"], node.kwargs
)
strat = get_placement_options(
self.mesh, node.target, user_strats, user_args, user_kwargs
)
strats[node] = strat
if local_map_kwargs := node.meta.get("custom", {}).get(
"dtensor_local_map_kwargs"
Contributor: Why are we testing the custom meta as opposed to looking at op.target to identify a local map HOP?

Member Author: we still need to pass the placement kwargs somehow, op.target alone wouldn't be enough

):
assert "call_local_map" in str(node.target)
assert not user_kwargs
strat = get_local_map_placement_option(
self.mesh,
user_strats,
user_args,
node.meta["val"],
local_map_kwargs["in_placements"],
local_map_kwargs["out_placements"],
)

assert not node.kwargs
node.kwargs = {
"_inline": True
} # notify the HOP to desugar in the next trace

strats[node] = strat
else:
strat = get_placement_options(
self.mesh, node.target, user_strats, user_args, user_kwargs
)
strats[node] = strat
elif node.op == "output":
user_strats = tree_map_only(
torch.fx.Node, lambda x: strats[x], node.args
78 changes: 76 additions & 2 deletions autoparallel/utils.py
@@ -4,9 +4,15 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributed._tensor.placement_types import TensorMeta
from torch.distributed._tensor.placement_types import Placement, TensorMeta
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, TupleStrategy
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpSpec,
OpStrategy,
TupleStrategy,
)
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
from torch.utils._pytree import tree_flatten, tree_map_only

@@ -123,6 +129,74 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
return out_strat


def get_local_map_placement_option(
mesh,
specs,
user_args,
output_val,
in_placements,
out_placements,
):
in_specs = []
for example, placement in zip(user_args, in_placements):
if placement is None:
# not a dtensor
assert False, "Not sure how to create DTensorSpec for this input"

in_specs.append(
DTensorSpec(
mesh=mesh,
placements=placement,
tensor_meta=TensorMeta(
shape=example.shape,
stride=example.stride(),
dtype=example.dtype,
),
)
)

out_specs = []
assert isinstance(output_val, (torch.Tensor, list, tuple))
outs = output_val if isinstance(output_val, (list, tuple)) else [output_val]
for example, placement in zip(outs, out_placements):
if placement is None:
# not a dtensor
assert False, "Not sure how to create DTensorSpec for this output"
elif isinstance(placement, Placement):
placement = [placement]

assert isinstance(placement, (list, tuple)), "Not implemented"
out_specs.append(
DTensorSpec(
mesh=mesh,
placements=placement,
tensor_meta=TensorMeta(
shape=example.shape,
stride=example.stride(),
dtype=example.dtype,
),
)
)

if len(out_specs) == 1:
out_specs = out_specs[0]

redistribute_costs = []
for user_strategy, input_spec in zip(specs, in_specs):
costs = generate_redistribute_costs(user_strategy, input_spec)
redistribute_costs.append(costs)

return OpStrategy(
[
OpSpec(
output_specs=out_specs,
input_specs=in_specs,
redistribute_cost=redistribute_costs,
)
]
)


def _get_device_from_mesh(mesh):
if mesh.device_type == "cpu":
return torch.device("cpu")