From ce9d253615e7af929819dade830761f1201f56e0 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 18 Jul 2025 18:17:06 -0700 Subject: [PATCH 01/10] local_map wip --- autoparallel/optimize_sharding.py | 79 +++++++++- examples/example_local_map.py | 253 ++++++++++++++++++++++++++++++ 2 files changed, 327 insertions(+), 5 deletions(-) create mode 100644 examples/example_local_map.py diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 137be51b..aee3e168 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -81,7 +81,7 @@ import pulp import torch -from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor.placement_types import Replicate, Shard from torch.utils._pytree import tree_flatten, tree_map_only @@ -146,10 +146,79 @@ def build_sharding_metadata(self): user_kwargs = tree_map_only( torch.fx.Node, lambda x: x.meta["val"], node.kwargs ) - strat = get_placement_options( - self.mesh, node.target, user_strats, user_args, user_kwargs - ) - strats[node] = strat + if local_map_kwargs := node.meta.get("custom", {}).get("local_map_kwargs"): + in_placements = local_map_kwargs["in_placements"] + out_placements = local_map_kwargs["out_placements"] + in_specs = [] + for input_arg, placement in zip(node.args, in_placements): + if placement is None: + # not a dtensor + assert False, "Not sure how to create DTensorSpec for this input" + + assert isinstance(placement, list), "Not implemented" + example = input_arg.meta["val"] + in_specs.append( + DTensorSpec( + mesh=self.mesh, + placements=placement, + tensor_meta=TensorMeta( + shape=example.shape, + stride=example.stride(), + dtype=example.dtype, + ) + ) + ) + + out_specs = [] + assert isinstance(node.meta["val"], (torch.Tensor, list, tuple)) + outs = node.meta["val"] if isinstance(node.meta["val"], (list, tuple)) else [node.meta["val"]] + for example, placement in zip(outs, out_placements): + from torch.distributed.tensor.placement_types import Placement + if placement is None: + # not a dtensor + assert False, "Not sure how to create DTensorSpec for this output" + elif isinstance(placement, Placement): + placement = [placement] + + assert isinstance(placement, (list, tuple)), "Not implemented" + out_specs.append( + DTensorSpec( + mesh=self.mesh, + placements=placement, + tensor_meta=TensorMeta( + shape=example.shape, + stride=example.stride(), + dtype=example.dtype, + ) + ) + ) + + from torch.distributed.tensor._op_schema import OpStrategy, OpSpec + from torch.distributed.tensor._ops.utils import generate_redistribute_costs + + redistribute_costs = [] + for input_arg, input_spec in zip(node.args, in_specs): + assert isinstance(input_arg, torch.fx.Node) + input_node_strategy = strats[input_arg] + costs = generate_redistribute_costs(input_node_strategy, input_spec) + redistribute_costs.append(costs) + + if len(out_specs) == 1: + out_specs = out_specs[0] + + strat = OpStrategy([ + OpSpec( + output_specs=out_specs, + input_specs=in_specs, + redistribute_cost=redistribute_costs + ) + ]) + strats[node] = strat + else: + strat = get_placement_options( + self.mesh, node.target, user_strats, user_args, user_kwargs + ) + strats[node] = strat elif node.op == "output": user_strats = tree_map_only( torch.fx.Node, lambda x: strats[x], node.args diff --git a/examples/example_local_map.py b/examples/example_local_map.py new file mode 100644 index 00000000..6e4d4262 --- /dev/null +++ 
b/examples/example_local_map.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.distributed._tensor.experimental import local_map +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch._ops import HigherOrderOperator +import functools + +from autoparallel.api import AutoParallel + +from torch.fx.experimental.proxy_tensor import ( + get_proxy_slot, + ProxyTorchDispatchMode, + disable_proxy_modes_tracing, + track_tensor_tree +) +import torch.utils._pytree as pytree + + +# just to dump tlparse +torch.compile(lambda x: x + 1, backend="eager")(torch.rand(10)) + +class LocalMapAOTExportModule(HigherOrderOperator): + """ + A HOP that integrates with autoparallel's current frontend (aot_export_module). + This HOP exists starting the pre-solver graph and lives until we apply sharding. + During which, runtime_func will be inlined into the post-solver graph. + """ + + def __init__(self): + super().__init__("local_map_hop") + + def __call__(self, runtime_func, *args, **kwargs): + return super().__call__(runtime_func, *args, **kwargs) + + +local_map_hop = LocalMapAOTExportModule() + +# def fn(x): +# return x + +# local_map_hop(fn, torch.randn(10, 10)) +# breakpoint() + +class LocalMapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, runtime_func, *args, **kwargs): + ctx.save_for_backward(*args) + + with torch._C._AutoDispatchBelowAutograd(): # why + return local_map_hop(runtime_func, *args, **kwargs) + # out = runtime_func(*args, **kwargs) + # return out + # out = runtime_func(*args, **kwargs) + # breakpoint() + # return out + + @staticmethod + def backward(ctx, *grads): + # mmmmm could really use the backward graph here + fwd_inputs = ctx.saved_tensors + return None, *[torch.ones_like(i) * 12345 for i in fwd_inputs] + +@local_map_hop.py_impl(torch._C.DispatchKey.Autograd) +def autograd_key( + runtime_func, + *args, + **kwargs, +): + return LocalMapAutogradOp.apply(runtime_func, *args, **kwargs) + +@local_map_hop.py_functionalize_impl +def functional_mode_key(ctx, runtime_func, *args, **kwargs): + assert not kwargs + + + unwrapped_inputs = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + # TODO: local_map mutation checks + out = local_map_hop(runtime_func, *unwrapped_inputs) + return ctx.wrap_tensors(out) + +@local_map_hop.py_impl(FakeTensorMode) +def fake_mode_key( + mode, + runtime_func, + *args, + **kwargs, +): + with mode: + return runtime_func(*args, **kwargs) + +@local_map_hop.py_impl(ProxyTorchDispatchMode) +def proxy_mode_key( + proxy_mode, + runtime_func, + *args, + **kwargs, +): + assert proxy_mode is not None, "Mode should always be enabled for python fallback key" + assert len(kwargs) == 0 + + example_out = local_map_hop(runtime_func, *args, **kwargs) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + def another_wrapper(*another_args, **another_kwargs): + return functools.partial(local_map_hop, runtime_func)(*another_args, **another_kwargs) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", another_wrapper, proxy_args, {} + ) + out_proxy.node.meta["custom"] = { + "local_map_kwargs": runtime_func.local_map_kwargs, + } + return track_tensor_tree( 
+ example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +from typing import Callable, Optional + +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate +from torch.distributed.tensor.experimental._func_map import InputPlacements, OutputPlacements +from torch.distributed.device_mesh import _mesh_resources + +def apply_local_map(*local_map_args, **local_map_kwargs): + assert local_map_kwargs["redistribute_inputs"], "Autoparallel should always be allowed to redistribute inputs" + + # manually issue the hop, which will not be not necessary with a dynamo frontend + def decorator(fn): + @functools.wraps(fn) + def wrapped(*args, **kwargs): + def runtime_func(*runtime_args, **runtime_kwargs): + # hop doesn't like the functools.partial created by local_map + return local_map( + fn, + *local_map_args, + **local_map_kwargs, + )(*runtime_args, **runtime_kwargs) + runtime_func.local_map_kwargs = local_map_kwargs + return local_map_hop(runtime_func, *args, **kwargs) + + return wrapped + return decorator + + +@apply_local_map( + out_placements=[Replicate(),], + in_placements=([Replicate()], [Replicate()]), # intentionally suboptimal, just to test + redistribute_inputs=True, +) +def boosted(w, x): + return torch.matmul(x, w.t()) * 12345 + +class Block(nn.Module): + def __init__(self, nheads, dim1, dim2): + super().__init__() + self.nheads = nheads + bias = False + self.wq = nn.Linear(dim1, dim1, bias=bias) + self.wk = nn.Linear(dim1, dim1, bias=bias) + self.wv = nn.Linear(dim1, dim1, bias=bias) + self.wo = nn.Linear(dim1, dim1, bias=bias) + self.w1 = nn.Linear(dim1, dim2, bias=bias) + self.w2 = nn.Linear(dim2, dim1, bias=bias) + + def init_weights(self): + for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]: + torch.nn.init.normal_(lin.weight) + if lin.bias is not None: + torch.nn.init.normal_(lin.bias) + + def forward(self, x): + q = self.wq(x) + k = boosted(self.wk.weight, x) + # k = self.wk(x) + v = self.wv(x) + + q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + + o = nn.functional.scaled_dot_product_attention(q, k, v) + o = o.permute(0, 2, 1, 3).flatten(-2) + + o = self.wo(o) + + o0 = o + x + + o = self.w1(o0) + o = torch.nn.functional.relu(o) + o = self.w2(o) + + o = o0 + o + + return o + + +world_size = 256 + +fake_store = FakeStore() +torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size +) +# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) +mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (world_size // 8, 8), + mesh_dim_names=( + "dp", + "tp", + ), +) + +bs = 8 * mesh.shape[0] +seq_len = 256 +nheads = 48 +dim1 = 6144 +dim2 = dim1 * 4 + + +def input_fn(): + return torch.rand(bs, seq_len, dim1, device="cuda") + + +# parallelize the model +with torch.device("meta"): + model = Block(nheads, dim1, dim2) +autop = AutoParallel(model, input_fn, mesh) +autop.add_parameter_memory_constraint(low=None, high=None) + +x_sharding = (Shard(0), Replicate()) + +autop.add_input_constraints([x_sharding]) +autop.add_output_constraints([x_sharding]) + +sharding_placement = autop.optimize_placement() + +# AutoParallel produces a module with meta-DTensor parameters that need to be initialized +parallel_mod = autop.apply_placement(sharding_placement) +parallel_mod.to_empty(device="cuda") +parallel_mod.init_weights() + +# 
now let's run it +x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),) +out = parallel_mod(*x) +out.backward(torch.randn_like(out)) + +print("All good!") From 466cd1b45c75b4bce41a083c631cbcd8a9dd1f9b Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 22 Jul 2025 23:07:05 -0700 Subject: [PATCH 02/10] inference working, backward sharding not enforced --- autoparallel/optimize_sharding.py | 55 ++++--- examples/example_local_map.py | 239 +++++++++++++++++++++++++----- 2 files changed, 241 insertions(+), 53 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index aee3e168..405fbcac 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -146,14 +146,20 @@ def build_sharding_metadata(self): user_kwargs = tree_map_only( torch.fx.Node, lambda x: x.meta["val"], node.kwargs ) - if local_map_kwargs := node.meta.get("custom", {}).get("local_map_kwargs"): + if local_map_kwargs := node.meta.get("custom", {}).get( + "local_map_kwargs" + ): + assert not node.kwargs + node.kwargs = {"_inline": True} in_placements = local_map_kwargs["in_placements"] out_placements = local_map_kwargs["out_placements"] in_specs = [] for input_arg, placement in zip(node.args, in_placements): if placement is None: # not a dtensor - assert False, "Not sure how to create DTensorSpec for this input" + assert ( + False + ), "Not sure how to create DTensorSpec for this input" assert isinstance(placement, list), "Not implemented" example = input_arg.meta["val"] @@ -165,18 +171,25 @@ def build_sharding_metadata(self): shape=example.shape, stride=example.stride(), dtype=example.dtype, - ) + ), ) ) out_specs = [] assert isinstance(node.meta["val"], (torch.Tensor, list, tuple)) - outs = node.meta["val"] if isinstance(node.meta["val"], (list, tuple)) else [node.meta["val"]] + outs = ( + node.meta["val"] + if isinstance(node.meta["val"], (list, tuple)) + else [node.meta["val"]] + ) for example, placement in zip(outs, out_placements): from torch.distributed.tensor.placement_types import Placement + if placement is None: # not a dtensor - assert False, "Not sure how to create DTensorSpec for this output" + assert ( + False + ), "Not sure how to create DTensorSpec for this output" elif isinstance(placement, Placement): placement = [placement] @@ -189,30 +202,36 @@ def build_sharding_metadata(self): shape=example.shape, stride=example.stride(), dtype=example.dtype, - ) + ), ) ) - from torch.distributed.tensor._op_schema import OpStrategy, OpSpec - from torch.distributed.tensor._ops.utils import generate_redistribute_costs - + from torch.distributed.tensor._op_schema import OpSpec, OpStrategy + from torch.distributed.tensor._ops.utils import ( + generate_redistribute_costs, + ) + redistribute_costs = [] for input_arg, input_spec in zip(node.args, in_specs): assert isinstance(input_arg, torch.fx.Node) input_node_strategy = strats[input_arg] - costs = generate_redistribute_costs(input_node_strategy, input_spec) + costs = generate_redistribute_costs( + input_node_strategy, input_spec + ) redistribute_costs.append(costs) if len(out_specs) == 1: out_specs = out_specs[0] - - strat = OpStrategy([ - OpSpec( - output_specs=out_specs, - input_specs=in_specs, - redistribute_cost=redistribute_costs - ) - ]) + + strat = OpStrategy( + [ + OpSpec( + output_specs=out_specs, + input_specs=in_specs, + redistribute_cost=redistribute_costs, + ) + ] + ) strats[node] = strat else: strat = get_placement_options( diff --git a/examples/example_local_map.py 
b/examples/example_local_map.py index 6e4d4262..2e0b9277 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -3,29 +3,36 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -import torch -from torch import nn -from torch.distributed.tensor.placement_types import Replicate, Shard -from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.distributed._tensor.experimental import local_map -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch._ops import HigherOrderOperator import functools +from typing import Callable, Optional, Union -from autoparallel.api import AutoParallel - +import torch +import torch.utils._pytree as pytree +from torch import Tensor, nn +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.distributed._tensor.experimental import local_map +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate +from torch.distributed.tensor.experimental._func_map import ( + InputPlacements, + OutputPlacements, +) +from torch.distributed.tensor.placement_types import Replicate, Shard from torch.fx.experimental.proxy_tensor import ( - get_proxy_slot, ProxyTorchDispatchMode, disable_proxy_modes_tracing, - track_tensor_tree + get_proxy_slot, + track_tensor_tree, ) -import torch.utils._pytree as pytree +from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel.api import AutoParallel # just to dump tlparse torch.compile(lambda x: x + 1, backend="eager")(torch.rand(10)) + class LocalMapAOTExportModule(HigherOrderOperator): """ A HOP that integrates with autoparallel's current frontend (aot_export_module). 
@@ -48,24 +55,167 @@ def __call__(self, runtime_func, *args, **kwargs): # local_map_hop(fn, torch.randn(10, 10)) # breakpoint() +from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, + _set_compilation_env, + materialize_as_graph, + reenter_make_fx, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, + unique_graph_id, + validate_subgraph_args_types, +) + + +def create_fw_bw_graph( + fw_func, + *_args, +): + # See Note:[HOP create fw_bw graph] + + # All of these imports need to be here in order to avoid circular dependencies + from torch._dispatch.python import suspend_functionalization + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + from torch._subclasses.functional_tensor import disable_functional_mode + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + + def _from_fun( + t: Union[Tensor, torch.SymInt, int], + ) -> Union[Tensor, torch.SymInt, int]: + if isinstance(t, torch.Tensor): + return torch.empty_strided( + t.size(), + t.stride(), + device=t.device, + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + return t + + # If someone runs this hop under the default compiler backend ("eager") + # Then this path will be run with the actual user inputs. We convert them + # to fake tensors in order to not perform any actual compute. + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(_args) + if fake_mode is None: + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + with fake_mode: + fw_inputs = pytree.tree_map(_from_fun, _args) + + assert all( + isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs + ) + + # redundant? we already _from_fun'd the inputs + example_flat_out = pytree.tree_map( + _from_fun, + fw_func(*fw_inputs), + ) + example_grad = _from_fun(example_flat_out) + + from torch.fx.experimental.proxy_tensor import make_fx + + def joint_f( + example_grad, + *fw_inputs, + ): + def run_fwd(*fw_inputs): + outs = fw_func(*fw_inputs) + if not isinstance(outs, (list, tuple)): + outs = (outs,) + masks = [o.requires_grad for o in outs] + return (outs, masks) + + joint = create_joint(run_fwd, aot_config=dummy_aot_config) + optional_grad = [example_grad] if example_grad.requires_grad else [] + _, grads = joint(fw_inputs, optional_grad) + + return grads + + joint_graph = make_fx(joint_f)(example_grad, *fw_inputs) + # do i need to return fw_graph here? 
by definition it is traceable, so should be fine to run again with runtime_func + return None, joint_graph + + +# def create_fw_bw_graph( +# runtime_wrapper, +# *args, +# ): +# from torch._dispatch.python import suspend_functionalization +# from torch._functorch.aot_autograd import AOTConfig, create_joint +# from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +# from torch._subclasses.functional_tensor import disable_functional_mode +# from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing +# from typing import Union + +# with suspend_functionalization(), disable_functional_mode(): +# with disable_proxy_modes_tracing(): +# def _from_fun( +# t: Union[torch.Tensor, torch.SymInt, int], +# ) -> Union[torch.Tensor, torch.SymInt, int]: +# if isinstance(t, torch.Tensor): +# return torch.empty_strided( +# t.size(), +# t.stride(), +# device=t.device, +# dtype=t.dtype, +# requires_grad=t.requires_grad, +# ) +# return t + + +# fw_inputs = pytree.tree_map(_from_fun, args) +# assert all( +# isinstance(t, (FakeTensor, int, torch.SymInt)) +# for t in fw_inputs +# ) + +# out = runtime_wrapper(*fw_inputs) +# example_flat_out = pytree.tree_map( +# _from_fun, +# out, +# ) +# example_grad = _from_fun(example_flat_out) +# breakpoint() + +# return None + + class LocalMapAutogradOp(torch.autograd.Function): @staticmethod - def forward(ctx, runtime_func, *args, **kwargs): + def forward(ctx, runtime_func, bwd_func, *args, **kwargs): ctx.save_for_backward(*args) + save_tensors_and_symints_for_backward(ctx, args) + ctx._bwd_func = bwd_func + with torch._C._AutoDispatchBelowAutograd(): # why return local_map_hop(runtime_func, *args, **kwargs) - # out = runtime_func(*args, **kwargs) - # return out - # out = runtime_func(*args, **kwargs) - # breakpoint() - # return out @staticmethod def backward(ctx, *grads): - # mmmmm could really use the backward graph here - fwd_inputs = ctx.saved_tensors - return None, *[torch.ones_like(i) * 12345 for i in fwd_inputs] + args = saved_tensors_and_symints(ctx) + grad_ins = ctx._bwd_func(*grads, *args) + # TODO: hopify to make opaque + # grad_ins = local_map_backward_hop(ctx._bwd_func, *grads, *args) + return None, None, *grad_ins + @local_map_hop.py_impl(torch._C.DispatchKey.Autograd) def autograd_key( @@ -73,12 +223,19 @@ def autograd_key( *args, **kwargs, ): - return LocalMapAutogradOp.apply(runtime_func, *args, **kwargs) + if "_inline" in kwargs: + del kwargs["_inline"] + return runtime_func(*args, **kwargs) + + # else trace + # trace joint, pass to .apply + _, bw_graph = create_fw_bw_graph(runtime_func, *args) + return LocalMapAutogradOp.apply(runtime_func, bw_graph, *args, **kwargs) + @local_map_hop.py_functionalize_impl def functional_mode_key(ctx, runtime_func, *args, **kwargs): - assert not kwargs - + assert not kwargs unwrapped_inputs = ctx.unwrap_tensors(args) with ctx.redispatch_to_next(): @@ -86,6 +243,7 @@ def functional_mode_key(ctx, runtime_func, *args, **kwargs): out = local_map_hop(runtime_func, *unwrapped_inputs) return ctx.wrap_tensors(out) + @local_map_hop.py_impl(FakeTensorMode) def fake_mode_key( mode, @@ -96,6 +254,7 @@ def fake_mode_key( with mode: return runtime_func(*args, **kwargs) + @local_map_hop.py_impl(ProxyTorchDispatchMode) def proxy_mode_key( proxy_mode, @@ -103,13 +262,19 @@ def proxy_mode_key( *args, **kwargs, ): - assert proxy_mode is not None, "Mode should always be enabled for python fallback key" + assert ( + proxy_mode is not None + ), "Mode should always be enabled for python fallback key" assert 
len(kwargs) == 0 example_out = local_map_hop(runtime_func, *args, **kwargs) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + def another_wrapper(*another_args, **another_kwargs): - return functools.partial(local_map_hop, runtime_func)(*another_args, **another_kwargs) + return functools.partial(local_map_hop, runtime_func)( + *another_args, **another_kwargs + ) + out_proxy = proxy_mode.tracer.create_proxy( "call_function", another_wrapper, proxy_args, {} ) @@ -121,14 +286,10 @@ def another_wrapper(*another_args, **another_kwargs): ) -from typing import Callable, Optional - -from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate -from torch.distributed.tensor.experimental._func_map import InputPlacements, OutputPlacements -from torch.distributed.device_mesh import _mesh_resources - def apply_local_map(*local_map_args, **local_map_kwargs): - assert local_map_kwargs["redistribute_inputs"], "Autoparallel should always be allowed to redistribute inputs" + assert local_map_kwargs[ + "redistribute_inputs" + ], "Autoparallel should always be allowed to redistribute inputs" # manually issue the hop, which will not be not necessary with a dynamo frontend def decorator(fn): @@ -141,21 +302,29 @@ def runtime_func(*runtime_args, **runtime_kwargs): *local_map_args, **local_map_kwargs, )(*runtime_args, **runtime_kwargs) + runtime_func.local_map_kwargs = local_map_kwargs return local_map_hop(runtime_func, *args, **kwargs) return wrapped + return decorator @apply_local_map( - out_placements=[Replicate(),], - in_placements=([Replicate()], [Replicate()]), # intentionally suboptimal, just to test + out_placements=[ + [Replicate(), Replicate()], + ], + in_placements=( + [Replicate(), Replicate()], + [Replicate(), Replicate()], + ), # intentionally suboptimal, just to test redistribute_inputs=True, ) def boosted(w, x): return torch.matmul(x, w.t()) * 12345 + class Block(nn.Module): def __init__(self, nheads, dim1, dim2): super().__init__() From 460ebe07816332a872871c8dd6c775ddf1594bdd Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 23 Jul 2025 08:15:33 -0700 Subject: [PATCH 03/10] clean up --- autoparallel/optimize_sharding.py | 99 ++------- autoparallel/utils.py | 74 ++++++- examples/example_local_map.py | 333 +++--------------------------- 3 files changed, 115 insertions(+), 391 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 405fbcac..f4587bdf 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -81,7 +81,7 @@ import pulp import torch -from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor.placement_types import Replicate, Shard from torch.utils._pytree import tree_flatten, tree_map_only @@ -90,7 +90,7 @@ estimate_strategy_runtime_cost, ) from .propagation_rules import _create_all_options -from .utils import get_placement_options +from .utils import get_local_map_placement_option, get_placement_options def _debug_node(node): @@ -147,91 +147,24 @@ def build_sharding_metadata(self): torch.fx.Node, lambda x: x.meta["val"], node.kwargs ) if local_map_kwargs := node.meta.get("custom", {}).get( - "local_map_kwargs" + "dtensor_local_map_kwargs" ): - assert not node.kwargs - node.kwargs = {"_inline": True} - in_placements = local_map_kwargs["in_placements"] - out_placements = local_map_kwargs["out_placements"] - in_specs = [] - for input_arg, placement in 
zip(node.args, in_placements): - if placement is None: - # not a dtensor - assert ( - False - ), "Not sure how to create DTensorSpec for this input" - - assert isinstance(placement, list), "Not implemented" - example = input_arg.meta["val"] - in_specs.append( - DTensorSpec( - mesh=self.mesh, - placements=placement, - tensor_meta=TensorMeta( - shape=example.shape, - stride=example.stride(), - dtype=example.dtype, - ), - ) - ) - - out_specs = [] - assert isinstance(node.meta["val"], (torch.Tensor, list, tuple)) - outs = ( - node.meta["val"] - if isinstance(node.meta["val"], (list, tuple)) - else [node.meta["val"]] + strat = get_local_map_placement_option( + self.mesh, + node.target, + user_strats, + user_args, + user_kwargs, + node.meta["val"], + local_map_kwargs["in_placements"], + local_map_kwargs["out_placements"], ) - for example, placement in zip(outs, out_placements): - from torch.distributed.tensor.placement_types import Placement - - if placement is None: - # not a dtensor - assert ( - False - ), "Not sure how to create DTensorSpec for this output" - elif isinstance(placement, Placement): - placement = [placement] - - assert isinstance(placement, (list, tuple)), "Not implemented" - out_specs.append( - DTensorSpec( - mesh=self.mesh, - placements=placement, - tensor_meta=TensorMeta( - shape=example.shape, - stride=example.stride(), - dtype=example.dtype, - ), - ) - ) - from torch.distributed.tensor._op_schema import OpSpec, OpStrategy - from torch.distributed.tensor._ops.utils import ( - generate_redistribute_costs, - ) + assert not node.kwargs + node.kwargs = { + "_inline": True + } # notify the HOP to desugar in the next trace - redistribute_costs = [] - for input_arg, input_spec in zip(node.args, in_specs): - assert isinstance(input_arg, torch.fx.Node) - input_node_strategy = strats[input_arg] - costs = generate_redistribute_costs( - input_node_strategy, input_spec - ) - redistribute_costs.append(costs) - - if len(out_specs) == 1: - out_specs = out_specs[0] - - strat = OpStrategy( - [ - OpSpec( - output_specs=out_specs, - input_specs=in_specs, - redistribute_cost=redistribute_costs, - ) - ] - ) strats[node] = strat else: strat = get_placement_options( diff --git a/autoparallel/utils.py b/autoparallel/utils.py index e9825ffb..72656c60 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -4,9 +4,15 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torch.distributed._tensor.placement_types import TensorMeta +from torch.distributed._tensor.placement_types import Placement, TensorMeta from torch.distributed.device_mesh import _get_device_handle -from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, TupleStrategy +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + TupleStrategy, +) from torch.distributed.tensor._ops.utils import generate_redistribute_costs from torch.utils._pytree import tree_flatten, tree_map_only @@ -123,6 +129,70 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): return out_strat +def get_local_map_placement_option( + mesh, op, specs, user_args, user_kwargs, output_val, in_placements, out_placements +): + in_specs = [] + for example, placement in zip(user_args, in_placements): + if placement is None: + # not a dtensor + assert False, "Not sure how to create DTensorSpec for this input" + + assert isinstance(placement, (list, tuple)), "Not implemented" + in_specs.append( + DTensorSpec( + mesh=mesh, + placements=placement, + tensor_meta=TensorMeta( + shape=example.shape, + stride=example.stride(), + dtype=example.dtype, + ), + ) + ) + + out_specs = [] + assert isinstance(output_val, (torch.Tensor, list, tuple)) + outs = output_val if isinstance(output_val, (list, tuple)) else [output_val] + for example, placement in zip(outs, out_placements): + if placement is None: + # not a dtensor + assert False, "Not sure how to create DTensorSpec for this output" + elif isinstance(placement, Placement): + placement = [placement] + + assert isinstance(placement, (list, tuple)), "Not implemented" + out_specs.append( + DTensorSpec( + mesh=mesh, + placements=placement, + tensor_meta=TensorMeta( + shape=example.shape, + stride=example.stride(), + dtype=example.dtype, + ), + ) + ) + + if len(out_specs) == 1: + out_specs = out_specs[0] + + redistribute_costs = [] + for user_strategy, input_spec in zip(specs, in_specs): + costs = generate_redistribute_costs(user_strategy, input_spec) + redistribute_costs.append(costs) + + return OpStrategy( + [ + OpSpec( + output_specs=out_specs, + input_specs=in_specs, + redistribute_cost=redistribute_costs, + ) + ] + ) + + def _get_device_from_mesh(mesh): if mesh.device_type == "cpu": return torch.device("cpu") diff --git a/examples/example_local_map.py b/examples/example_local_map.py index 2e0b9277..12a6fe10 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -3,314 +3,18 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
-import functools -from typing import Callable, Optional, Union - import torch -import torch.utils._pytree as pytree -from torch import Tensor, nn -from torch._ops import HigherOrderOperator -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.distributed._tensor.experimental import local_map -from torch.distributed.device_mesh import _mesh_resources -from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate -from torch.distributed.tensor.experimental._func_map import ( - InputPlacements, - OutputPlacements, -) +from torch import nn from torch.distributed.tensor.placement_types import Replicate, Shard -from torch.fx.experimental.proxy_tensor import ( - ProxyTorchDispatchMode, - disable_proxy_modes_tracing, - get_proxy_slot, - track_tensor_tree, -) from torch.testing._internal.distributed.fake_pg import FakeStore from autoparallel.api import AutoParallel +from autoparallel.local_map_hop import apply_local_map -# just to dump tlparse +# just to force dump tlparse torch.compile(lambda x: x + 1, backend="eager")(torch.rand(10)) -class LocalMapAOTExportModule(HigherOrderOperator): - """ - A HOP that integrates with autoparallel's current frontend (aot_export_module). - This HOP exists starting the pre-solver graph and lives until we apply sharding. - During which, runtime_func will be inlined into the post-solver graph. - """ - - def __init__(self): - super().__init__("local_map_hop") - - def __call__(self, runtime_func, *args, **kwargs): - return super().__call__(runtime_func, *args, **kwargs) - - -local_map_hop = LocalMapAOTExportModule() - -# def fn(x): -# return x - -# local_map_hop(fn, torch.randn(10, 10)) -# breakpoint() - -from torch._higher_order_ops.utils import ( - _maybe_run_with_interpreter, - _set_compilation_env, - materialize_as_graph, - reenter_make_fx, - save_tensors_and_symints_for_backward, - saved_tensors_and_symints, - unique_graph_id, - validate_subgraph_args_types, -) - - -def create_fw_bw_graph( - fw_func, - *_args, -): - # See Note:[HOP create fw_bw graph] - - # All of these imports need to be here in order to avoid circular dependencies - from torch._dispatch.python import suspend_functionalization - from torch._functorch.aot_autograd import AOTConfig, create_joint - from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode - from torch._subclasses.functional_tensor import disable_functional_mode - from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing - - dummy_aot_config = AOTConfig( - fw_compiler=None, # type: ignore[arg-type] - bw_compiler=None, # type: ignore[arg-type] - partition_fn=None, # type: ignore[arg-type] - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False, - ) - - with suspend_functionalization(), disable_functional_mode(): - with disable_proxy_modes_tracing(): - - def _from_fun( - t: Union[Tensor, torch.SymInt, int], - ) -> Union[Tensor, torch.SymInt, int]: - if isinstance(t, torch.Tensor): - return torch.empty_strided( - t.size(), - t.stride(), - device=t.device, - dtype=t.dtype, - requires_grad=t.requires_grad, - ) - return t - - # If someone runs this hop under the default compiler backend ("eager") - # Then this path will be run with the actual user inputs. We convert them - # to fake tensors in order to not perform any actual compute. 
- from torch._guards import detect_fake_mode - - fake_mode = detect_fake_mode(_args) - if fake_mode is None: - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) - - with fake_mode: - fw_inputs = pytree.tree_map(_from_fun, _args) - - assert all( - isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs - ) - - # redundant? we already _from_fun'd the inputs - example_flat_out = pytree.tree_map( - _from_fun, - fw_func(*fw_inputs), - ) - example_grad = _from_fun(example_flat_out) - - from torch.fx.experimental.proxy_tensor import make_fx - - def joint_f( - example_grad, - *fw_inputs, - ): - def run_fwd(*fw_inputs): - outs = fw_func(*fw_inputs) - if not isinstance(outs, (list, tuple)): - outs = (outs,) - masks = [o.requires_grad for o in outs] - return (outs, masks) - - joint = create_joint(run_fwd, aot_config=dummy_aot_config) - optional_grad = [example_grad] if example_grad.requires_grad else [] - _, grads = joint(fw_inputs, optional_grad) - - return grads - - joint_graph = make_fx(joint_f)(example_grad, *fw_inputs) - # do i need to return fw_graph here? by definition it is traceable, so should be fine to run again with runtime_func - return None, joint_graph - - -# def create_fw_bw_graph( -# runtime_wrapper, -# *args, -# ): -# from torch._dispatch.python import suspend_functionalization -# from torch._functorch.aot_autograd import AOTConfig, create_joint -# from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -# from torch._subclasses.functional_tensor import disable_functional_mode -# from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing -# from typing import Union - -# with suspend_functionalization(), disable_functional_mode(): -# with disable_proxy_modes_tracing(): -# def _from_fun( -# t: Union[torch.Tensor, torch.SymInt, int], -# ) -> Union[torch.Tensor, torch.SymInt, int]: -# if isinstance(t, torch.Tensor): -# return torch.empty_strided( -# t.size(), -# t.stride(), -# device=t.device, -# dtype=t.dtype, -# requires_grad=t.requires_grad, -# ) -# return t - - -# fw_inputs = pytree.tree_map(_from_fun, args) -# assert all( -# isinstance(t, (FakeTensor, int, torch.SymInt)) -# for t in fw_inputs -# ) - -# out = runtime_wrapper(*fw_inputs) -# example_flat_out = pytree.tree_map( -# _from_fun, -# out, -# ) -# example_grad = _from_fun(example_flat_out) -# breakpoint() - -# return None - - -class LocalMapAutogradOp(torch.autograd.Function): - @staticmethod - def forward(ctx, runtime_func, bwd_func, *args, **kwargs): - ctx.save_for_backward(*args) - - save_tensors_and_symints_for_backward(ctx, args) - ctx._bwd_func = bwd_func - - with torch._C._AutoDispatchBelowAutograd(): # why - return local_map_hop(runtime_func, *args, **kwargs) - - @staticmethod - def backward(ctx, *grads): - args = saved_tensors_and_symints(ctx) - grad_ins = ctx._bwd_func(*grads, *args) - # TODO: hopify to make opaque - # grad_ins = local_map_backward_hop(ctx._bwd_func, *grads, *args) - return None, None, *grad_ins - - -@local_map_hop.py_impl(torch._C.DispatchKey.Autograd) -def autograd_key( - runtime_func, - *args, - **kwargs, -): - if "_inline" in kwargs: - del kwargs["_inline"] - return runtime_func(*args, **kwargs) - - # else trace - # trace joint, pass to .apply - _, bw_graph = create_fw_bw_graph(runtime_func, *args) - return LocalMapAutogradOp.apply(runtime_func, bw_graph, *args, **kwargs) - - -@local_map_hop.py_functionalize_impl -def functional_mode_key(ctx, runtime_func, *args, **kwargs): - assert not kwargs - - unwrapped_inputs = ctx.unwrap_tensors(args) - with 
ctx.redispatch_to_next(): - # TODO: local_map mutation checks - out = local_map_hop(runtime_func, *unwrapped_inputs) - return ctx.wrap_tensors(out) - - -@local_map_hop.py_impl(FakeTensorMode) -def fake_mode_key( - mode, - runtime_func, - *args, - **kwargs, -): - with mode: - return runtime_func(*args, **kwargs) - - -@local_map_hop.py_impl(ProxyTorchDispatchMode) -def proxy_mode_key( - proxy_mode, - runtime_func, - *args, - **kwargs, -): - assert ( - proxy_mode is not None - ), "Mode should always be enabled for python fallback key" - assert len(kwargs) == 0 - - example_out = local_map_hop(runtime_func, *args, **kwargs) - proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) - - def another_wrapper(*another_args, **another_kwargs): - return functools.partial(local_map_hop, runtime_func)( - *another_args, **another_kwargs - ) - - out_proxy = proxy_mode.tracer.create_proxy( - "call_function", another_wrapper, proxy_args, {} - ) - out_proxy.node.meta["custom"] = { - "local_map_kwargs": runtime_func.local_map_kwargs, - } - return track_tensor_tree( - example_out, out_proxy, constant=None, tracer=proxy_mode.tracer - ) - - -def apply_local_map(*local_map_args, **local_map_kwargs): - assert local_map_kwargs[ - "redistribute_inputs" - ], "Autoparallel should always be allowed to redistribute inputs" - - # manually issue the hop, which will not be not necessary with a dynamo frontend - def decorator(fn): - @functools.wraps(fn) - def wrapped(*args, **kwargs): - def runtime_func(*runtime_args, **runtime_kwargs): - # hop doesn't like the functools.partial created by local_map - return local_map( - fn, - *local_map_args, - **local_map_kwargs, - )(*runtime_args, **runtime_kwargs) - - runtime_func.local_map_kwargs = local_map_kwargs - return local_map_hop(runtime_func, *args, **kwargs) - - return wrapped - - return decorator - - @apply_local_map( out_placements=[ [Replicate(), Replicate()], @@ -318,18 +22,34 @@ def runtime_func(*runtime_args, **runtime_kwargs): in_placements=( [Replicate(), Replicate()], [Replicate(), Replicate()], - ), # intentionally suboptimal, just to test + [Replicate(), Replicate()], + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, +) +def replicate_linear(w, bias, x): + return torch.matmul(x, w.t()) + bias + + +@apply_local_map( + out_placements=[ + [Shard(0), Shard(0)], + ], + in_placements=([Shard(0), Shard(0)],), redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, ) -def boosted(w, x): - return torch.matmul(x, w.t()) * 12345 +def sharded_pointwise(x): + return x + 10 class Block(nn.Module): def __init__(self, nheads, dim1, dim2): super().__init__() self.nheads = nheads - bias = False + bias = True self.wq = nn.Linear(dim1, dim1, bias=bias) self.wk = nn.Linear(dim1, dim1, bias=bias) self.wv = nn.Linear(dim1, dim1, bias=bias) @@ -344,9 +64,10 @@ def init_weights(self): torch.nn.init.normal_(lin.bias) def forward(self, x): - q = self.wq(x) - k = boosted(self.wk.weight, x) - # k = self.wk(x) + boosted_weight = sharded_pointwise(self.wq.weight) + q = replicate_linear(boosted_weight, self.wq.bias, x) + # q = self.wq(x) + k = self.wk(x) v = self.wv(x) q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) From 1d4c098c3984cf3077d46f16bb781f39fc4bd085 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 23 Jul 2025 10:47:49 -0700 Subject: [PATCH 04/10] add missing file --- autoparallel/local_map_hop.py | 243 ++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 
autoparallel/local_map_hop.py diff --git a/autoparallel/local_map_hop.py b/autoparallel/local_map_hop.py new file mode 100644 index 00000000..79428e59 --- /dev/null +++ b/autoparallel/local_map_hop.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# NOTE: this file may be removed once we move to a dynamo frontend + +import functools + + +import torch +import torch.utils._pytree as pytree +from torch._higher_order_ops.utils import ( + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tensor.experimental import local_map +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + +class LocalMapAOTExportModule(HigherOrderOperator): + """ + A HOP that integrates with autoparallel's current frontend (aot_export_module). + This HOP exists starting the pre-solver graph and lives until we apply sharding. + During which, orig_fwd will be inlined into the post-solver graph. + """ + + def __init__(self): + super().__init__("local_map_hop") + + def __call__(self, orig_fwd, *args, **kwargs): + return super().__call__(orig_fwd, *args, **kwargs) + + +local_map_hop = LocalMapAOTExportModule() + +def create_hop_joint_graph( + fw_func, + *_args, +): + # Keeping these imports here + # Avoid circular dependencies once we upstream with dynamo frontend + from torch._dispatch.python import suspend_functionalization + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._guards import detect_fake_mode + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + from torch._subclasses.functional_tensor import disable_functional_mode + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + + def _from_fun(t): + if isinstance(t, torch.Tensor): + return torch.empty_strided( + t.size(), + t.stride(), + device=t.device, + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + return t + + # If someone runs this hop under the default compiler backend ("eager") + # Then this path will be run with the actual user inputs. We convert them + # to fake tensors in order to not perform any actual compute. + + fake_mode = detect_fake_mode(_args) + if fake_mode is None: + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + with fake_mode: + fw_inputs = pytree.tree_map(_from_fun, _args) + + assert all( + isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs + ) + + # redundant? 
we already _from_fun'd the inputs + example_flat_out = pytree.tree_map( + _from_fun, + fw_func(*fw_inputs), + ) + example_grads = _from_fun(example_flat_out) + if not isinstance(example_grads, (list, tuple)): + example_grads = [example_grads] + + def joint_f( + *primals_and_tangents, + ): + fw_inputs = primals_and_tangents[: len(_args)] + example_grads = primals_and_tangents[len(_args) :] + + def run_fwd(*fw_inputs): + outs = fw_func(*fw_inputs) + if not isinstance(outs, (list, tuple)): + outs = (outs,) + masks = [o.requires_grad for o in outs] + return (outs, masks) + + joint = create_joint(run_fwd, aot_config=dummy_aot_config) + optional_grads = [] + for example_grad in example_grads: + if example_grad.requires_grad: + optional_grads.append(example_grad) + _, grads = joint(fw_inputs, optional_grads) + return grads + + primals_and_tangents = [*fw_inputs, *example_grads] + joint_graph = make_fx(joint_f)(*primals_and_tangents) + return None, joint_graph + + +class LocalMapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, orig_fwd, joint_graph, *args, **kwargs): + ctx.save_for_backward(*args) + + save_tensors_and_symints_for_backward(ctx, args) + ctx.joint_graph = joint_graph + + with torch._C._AutoDispatchBelowAutograd(): # why + return local_map_hop(orig_fwd, *args, **kwargs) + + @staticmethod + def backward(ctx, *grads): + args = saved_tensors_and_symints(ctx) + grad_ins = ctx.joint_graph(*args, *grads) + # TODO: hopify to support local_map'd function containing custom autograd.Function + return None, None, *grad_ins + + +@local_map_hop.py_impl(torch._C.DispatchKey.Autograd) +def autograd_key( + orig_fwd, + *args, + **kwargs, +): + if "_inline" in kwargs: + # Solver pass adds a _inline kwarg, which tells this hop to desugar on the next trace + del kwargs["_inline"] + return orig_fwd(*args, **kwargs) + + _, joint_graph = create_hop_joint_graph(orig_fwd, *args) + return LocalMapAutogradOp.apply(orig_fwd, joint_graph, *args, **kwargs) + + +@local_map_hop.py_functionalize_impl +def functional_mode_key(ctx, orig_fwd, *args, **kwargs): + assert not kwargs + + unwrapped_inputs = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + # TODO: dynamo safety checks on orig_fwd + out = local_map_hop(orig_fwd, *unwrapped_inputs) + return ctx.wrap_tensors(out) + + +@local_map_hop.py_impl(FakeTensorMode) +def fake_mode_key( + mode, + orig_fwd, + *args, + **kwargs, +): + with mode: + return orig_fwd(*args, **kwargs) + + +@local_map_hop.py_impl(ProxyTorchDispatchMode) +def proxy_mode_key( + proxy_mode, + orig_fwd, + *args, + **kwargs, +): + assert ( + proxy_mode is not None + ), "Mode should always be enabled for python fallback key" + assert len(kwargs) == 0 + + example_out = local_map_hop(orig_fwd, *args, **kwargs) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + def call_local_map(*another_args, **another_kwargs): + return functools.partial(local_map_hop, orig_fwd)( + *another_args, **another_kwargs + ) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", call_local_map, proxy_args, {} + ) + out_proxy.node.meta["custom"] = { + "dtensor_local_map_kwargs": orig_fwd.local_map_kwargs, + } + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +def apply_local_map(*local_map_args, **local_map_kwargs): + # NOTE: We manually issue the hop, which will not be not necessary with a dynamo frontend. + # 1. Same as local_map, must be applied on a function, not a method. + # 2. 
the local_map'd function must be make_fx traceable. Otherwise, we may + # inline the wrong graph. In a dynamo frontend, speculate_subgraph will handle this. + # 3. All inputs to the local_map'd function must be Tensor types. Otherwise, we won't + # know which tensors to apply _from_fun to. For instance, don't pass nn.Modules to local_map. + # In dynamo frontend, tensors will be lifted, and will modify the wrapped function's signature. + + assert local_map_kwargs[ + "redistribute_inputs" + ], "Autoparallel should always be allowed to redistribute inputs" + assert local_map_kwargs["in_grad_placements"] is None, "Not yet implemented" + assert local_map_kwargs["device_mesh"] is None, "Must be provided by Autoparallel" + + def decorator(fn): + @functools.wraps(fn) + def wrapped(*args, **kwargs): + def orig_fwd(*runtime_args, **runtime_kwargs): + # wrap the functools.partial for hop utils to work out of box + return local_map( + fn, + *local_map_args, + **local_map_kwargs, + )(*runtime_args, **runtime_kwargs) + + orig_fwd.local_map_kwargs = local_map_kwargs + return local_map_hop(orig_fwd, *args, **kwargs) + + return wrapped + + return decorator From ee6a2184a526193805dcea52d52865b7c734aaf3 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 23 Jul 2025 10:50:06 -0700 Subject: [PATCH 05/10] lint --- autoparallel/local_map_hop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autoparallel/local_map_hop.py b/autoparallel/local_map_hop.py index 79428e59..b123a7af 100644 --- a/autoparallel/local_map_hop.py +++ b/autoparallel/local_map_hop.py @@ -7,7 +7,6 @@ import functools - import torch import torch.utils._pytree as pytree from torch._higher_order_ops.utils import ( @@ -19,6 +18,7 @@ from torch.distributed._tensor.experimental import local_map from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + class LocalMapAOTExportModule(HigherOrderOperator): """ A HOP that integrates with autoparallel's current frontend (aot_export_module). @@ -35,6 +35,7 @@ def __call__(self, orig_fwd, *args, **kwargs): local_map_hop = LocalMapAOTExportModule() + def create_hop_joint_graph( fw_func, *_args, From 55976a56365e922b73c77cd375d1259a1b89ad56 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 24 Jul 2025 09:29:17 -0700 Subject: [PATCH 06/10] HOP in eager --- autoparallel/local_map_hop.py | 11 +++++++++++ examples/example_local_map.py | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/autoparallel/local_map_hop.py b/autoparallel/local_map_hop.py index b123a7af..ed662ca1 100644 --- a/autoparallel/local_map_hop.py +++ b/autoparallel/local_map_hop.py @@ -210,6 +210,17 @@ def call_local_map(*another_args, **another_kwargs): ) +# Running HOP in eager with real tensors +@local_map_hop.py_impl(torch._C.DispatchKey.CPU) +@local_map_hop.py_impl(torch._C.DispatchKey.CUDA) +def real_impl( + orig_fwd, + *args, + **kwargs, +): + return orig_fwd(*args, **kwargs) + + def apply_local_map(*local_map_args, **local_map_kwargs): # NOTE: We manually issue the hop, which will not be not necessary with a dynamo frontend. # 1. Same as local_map, must be applied on a function, not a method. 
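For context on this patch: with the CPU/CUDA py_impl registrations above, an apply_local_map-decorated function can also be called outside of tracing, in which case dispatch eventually reaches real_impl and simply runs the wrapped local_map'd function on the real tensors. A minimal sketch of that usage (assuming a process group and a 1-D device mesh are already initialized as in examples/example_local_map.py; the function name doubled is illustrative only, and the eager call is left commented out just like the eager checks in the example file):

    import torch
    from torch.distributed.tensor.placement_types import Replicate

    from autoparallel.local_map_hop import apply_local_map

    @apply_local_map(
        out_placements=((Replicate(),),),  # one tuple per output, one placement per mesh dim
        in_placements=((Replicate(),),),   # one tuple per tensor input
        redistribute_inputs=True,          # apply_local_map asserts this is allowed
        in_grad_placements=None,           # not yet implemented by the hop
        device_mesh=None,                  # must be None; autoparallel provides the mesh
    )
    def doubled(x):
        # runs on the local shard of x
        return x * 2

    # eager call with real tensors: real_impl -> orig_fwd -> local_map(fn, ...)
    # y = doubled(torch.randn(8, 8, device="cuda"))
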
diff --git a/examples/example_local_map.py b/examples/example_local_map.py index 12a6fe10..10aeef2c 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -117,6 +117,16 @@ def input_fn(): return torch.rand(bs, seq_len, dim1, device="cuda") +# HOP runs in eager with fake tensors +# from torch._subclasses import FakeTensorMode +# with FakeTensorMode(): +# model = Block(nheads, dim1, dim2).cuda() +# model(input_fn()) + +# HOP runs in eager with real tensors +# model = Block(nheads, dim1, dim2).cuda() +# model(input_fn()) + # parallelize the model with torch.device("meta"): model = Block(nheads, dim1, dim2) From 51814d434c274692d3ddc0a9b5e7ec79fd5942c2 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 24 Jul 2025 10:05:17 -0700 Subject: [PATCH 07/10] typing --- autoparallel/optimize_sharding.py | 4 ++-- autoparallel/utils.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index f4587bdf..a372195c 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -149,12 +149,12 @@ def build_sharding_metadata(self): if local_map_kwargs := node.meta.get("custom", {}).get( "dtensor_local_map_kwargs" ): + assert "call_local_map" in str(node.target) + assert not user_kwargs strat = get_local_map_placement_option( self.mesh, - node.target, user_strats, user_args, - user_kwargs, node.meta["val"], local_map_kwargs["in_placements"], local_map_kwargs["out_placements"], diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 72656c60..c75f6d73 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Union + import torch from torch.distributed._tensor.placement_types import Placement, TensorMeta from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.tensor import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, @@ -14,6 +17,10 @@ TupleStrategy, ) from torch.distributed.tensor._ops.utils import generate_redistribute_costs +from torch.distributed.tensor.experimental._func_map import ( + InputPlacements, + OutputPlacements, +) from torch.utils._pytree import tree_flatten, tree_map_only from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs @@ -130,7 +137,12 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): def get_local_map_placement_option( - mesh, op, specs, user_args, user_kwargs, output_val, in_placements, out_placements + mesh: DeviceMesh, + specs: tuple[OpStrategy], + user_args: tuple[torch.Tensor], + output_val: Union[torch.Tensor, tuple[torch.Tensor]], + in_placements: InputPlacements, + out_placements: OutputPlacements, ): in_specs = [] for example, placement in zip(user_args, in_placements): From 4fe5e37a0519cb360d5645ebcde82171648651bd Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 24 Jul 2025 19:59:30 -0700 Subject: [PATCH 08/10] CP example using 3D mesh --- examples/example_local_map.py | 42 ++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/examples/example_local_map.py b/examples/example_local_map.py index 10aeef2c..14a0484e 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch import nn from torch.distributed.tensor.placement_types import Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore @@ -17,12 +18,12 @@ @apply_local_map( out_placements=[ - [Replicate(), Replicate()], + [Replicate(), Replicate(), Replicate()], ], in_placements=( - [Replicate(), Replicate()], - [Replicate(), Replicate()], - [Replicate(), Replicate()], + [Replicate(), Replicate(), Replicate()], + [Replicate(), Replicate(), Replicate()], + [Replicate(), Replicate(), Replicate()], ), redistribute_inputs=True, in_grad_placements=None, @@ -34,9 +35,9 @@ def replicate_linear(w, bias, x): @apply_local_map( out_placements=[ - [Shard(0), Shard(0)], + [Shard(0), Shard(0), Replicate()], ], - in_placements=([Shard(0), Shard(0)],), + in_placements=([Shard(0), Shard(0), Replicate()],), redistribute_inputs=True, in_grad_placements=None, device_mesh=None, @@ -45,6 +46,26 @@ def sharded_pointwise(x): return x + 10 +@apply_local_map( + out_placements=[ + [Replicate(), Replicate(), Shard(2)], + ], + in_placements=( + [Replicate(), Replicate(), Shard(2)], + [Replicate(), Replicate(), Shard(2)], + [Replicate(), Replicate(), Shard(2)], + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, +) +def context_parallel_attention(query, key, value): + out = F.scaled_dot_product_attention( + query=query, key=key, value=value, is_causal=True + ) + return out + + class Block(nn.Module): def __init__(self, nheads, dim1, dim2): super().__init__() @@ -74,7 +95,7 @@ def forward(self, x): k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) - o = nn.functional.scaled_dot_product_attention(q, k, v) + o = context_parallel_attention(q, 
k, v) o = o.permute(0, 2, 1, 3).flatten(-2) o = self.wo(o) @@ -99,10 +120,11 @@ def forward(self, x): # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", - (world_size // 8, 8), + (world_size // 32, 8, 4), mesh_dim_names=( "dp", "tp", + "cp", ), ) @@ -133,7 +155,7 @@ def input_fn(): autop = AutoParallel(model, input_fn, mesh) autop.add_parameter_memory_constraint(low=None, high=None) -x_sharding = (Shard(0), Replicate()) +x_sharding = (Shard(0), Replicate(), Shard(1)) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) @@ -146,7 +168,7 @@ def input_fn(): parallel_mod.init_weights() # now let's run it -x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),) +x = (torch.rand(bs // mesh.shape[0], seq_len // mesh.shape[2], dim1, device="cuda"),) out = parallel_mod(*x) out.backward(torch.randn_like(out)) From 175f447e23b9b9fa02b63c71308c39b00b3606d9 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 25 Jul 2025 14:33:15 -0700 Subject: [PATCH 09/10] update CP sharding --- autoparallel/local_map_hop.py | 3 ++- examples/example_local_map.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/autoparallel/local_map_hop.py b/autoparallel/local_map_hop.py index ed662ca1..d2346d25 100644 --- a/autoparallel/local_map_hop.py +++ b/autoparallel/local_map_hop.py @@ -62,6 +62,7 @@ def create_hop_joint_graph( with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): + # create a tensor (fake) from a compiler wrapped FunctionalTensor def _from_fun(t): if isinstance(t, torch.Tensor): return torch.empty_strided( @@ -131,7 +132,7 @@ def forward(ctx, orig_fwd, joint_graph, *args, **kwargs): save_tensors_and_symints_for_backward(ctx, args) ctx.joint_graph = joint_graph - with torch._C._AutoDispatchBelowAutograd(): # why + with torch._C._AutoDispatchBelowAutograd(): return local_map_hop(orig_fwd, *args, **kwargs) @staticmethod diff --git a/examples/example_local_map.py b/examples/example_local_map.py index 14a0484e..4bc57b4c 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -48,12 +48,12 @@ def sharded_pointwise(x): @apply_local_map( out_placements=[ - [Replicate(), Replicate(), Shard(2)], + [Shard(0), Shard(1), Shard(2)], ], in_placements=( - [Replicate(), Replicate(), Shard(2)], - [Replicate(), Replicate(), Shard(2)], - [Replicate(), Replicate(), Shard(2)], + [Shard(0), Shard(1), Shard(2)], + [Shard(0), Shard(1), Replicate()], + [Shard(0), Shard(1), Replicate()], ), redistribute_inputs=True, in_grad_placements=None, @@ -61,7 +61,7 @@ def sharded_pointwise(x): ) def context_parallel_attention(query, key, value): out = F.scaled_dot_product_attention( - query=query, key=key, value=value, is_causal=True + query=query, key=key, value=value, is_causal=False ) return out From 17f543eb8bc093acdbd58ad06fa0afa5d28301d6 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 25 Jul 2025 15:45:06 -0700 Subject: [PATCH 10/10] undo lint --- autoparallel/utils.py | 20 ++++++-------------- examples/example_local_map.py | 26 ++++++++++---------------- 2 files changed, 16 insertions(+), 30 deletions(-) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index c75f6d73..21618e51 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -3,12 +3,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source 
tree. -from typing import Union - import torch from torch.distributed._tensor.placement_types import Placement, TensorMeta from torch.distributed.device_mesh import _get_device_handle -from torch.distributed.tensor import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, @@ -17,10 +14,6 @@ TupleStrategy, ) from torch.distributed.tensor._ops.utils import generate_redistribute_costs -from torch.distributed.tensor.experimental._func_map import ( - InputPlacements, - OutputPlacements, -) from torch.utils._pytree import tree_flatten, tree_map_only from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs @@ -137,12 +130,12 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): def get_local_map_placement_option( - mesh: DeviceMesh, - specs: tuple[OpStrategy], - user_args: tuple[torch.Tensor], - output_val: Union[torch.Tensor, tuple[torch.Tensor]], - in_placements: InputPlacements, - out_placements: OutputPlacements, + mesh, + specs, + user_args, + output_val, + in_placements, + out_placements, ): in_specs = [] for example, placement in zip(user_args, in_placements): @@ -150,7 +143,6 @@ def get_local_map_placement_option( # not a dtensor assert False, "Not sure how to create DTensorSpec for this input" - assert isinstance(placement, (list, tuple)), "Not implemented" in_specs.append( DTensorSpec( mesh=mesh, diff --git a/examples/example_local_map.py b/examples/example_local_map.py index 4bc57b4c..f2a14863 100644 --- a/examples/example_local_map.py +++ b/examples/example_local_map.py @@ -17,13 +17,11 @@ @apply_local_map( - out_placements=[ - [Replicate(), Replicate(), Replicate()], - ], + out_placements=((Replicate(), Replicate(), Replicate()),), in_placements=( - [Replicate(), Replicate(), Replicate()], - [Replicate(), Replicate(), Replicate()], - [Replicate(), Replicate(), Replicate()], + (Replicate(), Replicate(), Replicate()), + (Replicate(), Replicate(), Replicate()), + (Replicate(), Replicate(), Replicate()), ), redistribute_inputs=True, in_grad_placements=None, @@ -34,10 +32,8 @@ def replicate_linear(w, bias, x): @apply_local_map( - out_placements=[ - [Shard(0), Shard(0), Replicate()], - ], - in_placements=([Shard(0), Shard(0), Replicate()],), + out_placements=((Shard(0), Shard(0), Replicate()),), + in_placements=((Shard(0), Shard(0), Replicate()),), redistribute_inputs=True, in_grad_placements=None, device_mesh=None, @@ -47,13 +43,11 @@ def sharded_pointwise(x): @apply_local_map( - out_placements=[ - [Shard(0), Shard(1), Shard(2)], - ], + out_placements=((Shard(0), Shard(1), Shard(2)),), in_placements=( - [Shard(0), Shard(1), Shard(2)], - [Shard(0), Shard(1), Replicate()], - [Shard(0), Shard(1), Replicate()], + (Shard(0), Shard(1), Shard(2)), + (Shard(0), Shard(1), Replicate()), + (Shard(0), Shard(1), Replicate()), ), redistribute_inputs=True, in_grad_placements=None,
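For reference, the placement convention the example settles on in this final state: every tensor gets one placement tuple whose length matches the number of mesh dimensions, so with the ("dp", "tp", "cp") mesh each entry carries three placements, and only plain tensors may flow through the hop. A minimal sketch of decorating a new function in that style (assuming the same autoparallel.local_map_hop module from the patches above; the function name sharded_add and its shardings are illustrative only):

    import torch
    from torch.distributed.tensor.placement_types import Replicate, Shard

    from autoparallel.local_map_hop import apply_local_map

    @apply_local_map(
        # one placement tuple per output; three entries for the ("dp", "tp", "cp") mesh
        out_placements=((Shard(0), Replicate(), Replicate()),),
        # one placement tuple per tensor input, same rank as the mesh
        in_placements=(
            (Shard(0), Replicate(), Replicate()),
            (Shard(0), Replicate(), Replicate()),
        ),
        redistribute_inputs=True,  # autoparallel must be free to redistribute inputs
        in_grad_placements=None,   # not yet implemented by the hop
        device_mesh=None,          # supplied by autoparallel at solve time
    )
    def sharded_add(a, b):
        # body runs on local shards; keep all arguments as tensors
        return a + b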