Add local_map support #47
Changes from all commits: ce9d253, 466cd1b, 460ebe0, 1d4c098, ee6a218, 55976a5, 51814d4, 4fe5e37, 175f447, 17f543e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,256 @@ | ||
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| # NOTE: this file may be removed once we move to a dynamo frontend | ||
|
|
||
| import functools | ||
|
|
||
| import torch | ||
| import torch.utils._pytree as pytree | ||
| from torch._higher_order_ops.utils import ( | ||
| save_tensors_and_symints_for_backward, | ||
| saved_tensors_and_symints, | ||
| ) | ||
| from torch._ops import HigherOrderOperator | ||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||
| from torch.distributed._tensor.experimental import local_map | ||
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree | ||
|
|
||
|
|
||
| class LocalMapAOTExportModule(HigherOrderOperator): | ||
| """ | ||
| A HOP that integrates with autoparallel's current frontend (aot_export_module). | ||
| This HOP exists starting from the pre-solver graph and lives until we apply sharding, | ||
| at which point orig_fwd will be inlined into the post-solver graph. | ||
|
||
| """ | ||
|
|
||
| def __init__(self): | ||
| super().__init__("local_map_hop") | ||
|
Contributor: nit: maybe this is hard, but it would be nice if some info about orig_fwd (e.g. name/line) were included in the name of the hop, so that the graph doesn't just show "hop1, hop2, hop3" with no info about what's inside.
Member (Author): having stack traces in the autoparallel graphs for all nodes would address this issue too |
||
|
|
||
| def __call__(self, orig_fwd, *args, **kwargs): | ||
| return super().__call__(orig_fwd, *args, **kwargs) | ||
|
|
||
|
|
||
| local_map_hop = LocalMapAOTExportModule() | ||
|
|
||
|
|
||
| def create_hop_joint_graph( | ||
|
Contributor: Is this different than create_joint? |
||
| fw_func, | ||
| *_args, | ||
| ): | ||
| # Keep these imports here to avoid circular | ||
| # dependencies once we upstream with the dynamo frontend | ||
| from torch._dispatch.python import suspend_functionalization | ||
| from torch._functorch.aot_autograd import AOTConfig, create_joint | ||
| from torch._guards import detect_fake_mode | ||
| from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode | ||
| from torch._subclasses.functional_tensor import disable_functional_mode | ||
| from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx | ||
|
|
||
| dummy_aot_config = AOTConfig( | ||
| fw_compiler=None, # type: ignore[arg-type] | ||
| bw_compiler=None, # type: ignore[arg-type] | ||
| partition_fn=None, # type: ignore[arg-type] | ||
| decompositions={}, | ||
| num_params_buffers=0, | ||
| aot_id=0, | ||
| keep_inference_input_mutations=False, | ||
| ) | ||
|
|
||
| with suspend_functionalization(), disable_functional_mode(): | ||
| with disable_proxy_modes_tracing(): | ||
|
|
||
| # create a tensor (fake) from a compiler wrapped FunctionalTensor | ||
| def _from_fun(t): | ||
|
Contributor: more descriptive name? (maybe _empty_like() or something)
Contributor: oh, i totally missed that fun means functional. Although I still don't know what that means. |
||
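For illustration, here is a minimal standalone sketch (not part of this PR) of what this metadata-only conversion achieves under FakeTensorMode. The helper name `_empty_like_meta` is hypothetical, echoing the reviewer's suggestion.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def _empty_like_meta(t):
    # Hypothetical stand-in for _from_fun: keep only metadata
    # (shape / stride / dtype / device), so no real storage or compute is used.
    if isinstance(t, torch.Tensor):
        return torch.empty_strided(
            t.size(), t.stride(), device=t.device, dtype=t.dtype,
            requires_grad=t.requires_grad,
        )
    return t

real = torch.randn(4, 8, requires_grad=True)
with FakeTensorMode(allow_non_fake_inputs=True):
    fake = _empty_like_meta(real)

# The result carries only shapes/dtypes, not data.
assert isinstance(fake, FakeTensor) and fake.shape == (4, 8)
```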
| if isinstance(t, torch.Tensor): | ||
| return torch.empty_strided( | ||
| t.size(), | ||
| t.stride(), | ||
| device=t.device, | ||
| dtype=t.dtype, | ||
| requires_grad=t.requires_grad, | ||
| ) | ||
| return t | ||
|
|
||
| # If someone runs this hop under the default compiler backend ("eager"), | ||
| # then this path will be run with the actual user inputs. We convert them | ||
| # to fake tensors in order to not perform any actual compute. | ||
|
|
||
| fake_mode = detect_fake_mode(_args) | ||
| if fake_mode is None: | ||
| fake_mode = FakeTensorMode(allow_non_fake_inputs=True) | ||
|
|
||
| with fake_mode: | ||
| fw_inputs = pytree.tree_map(_from_fun, _args) | ||
|
|
||
| assert all( | ||
| isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs | ||
| ) | ||
|
|
||
| # redundant? we already _from_fun'd the inputs | ||
|
Contributor: is the concern that user code could allocate new tensors? if they did, they would already be fake because of the mode, so i'm not sure why we can't return them directly.
Member (Author): i'm not sure why other HOPs do this. you shouldn't need to unwrap the outputs if the inputs were already unwrapped, and like you said, allocating new tensors should just be fakes, and not be functional wrapped. the only thing i can think of is if |
||
| example_flat_out = pytree.tree_map( | ||
| _from_fun, | ||
| fw_func(*fw_inputs), | ||
| ) | ||
| example_grads = _from_fun(example_flat_out) | ||
| if not isinstance(example_grads, (list, tuple)): | ||
| example_grads = [example_grads] | ||
|
|
||
| def joint_f( | ||
| *primals_and_tangents, | ||
| ): | ||
| fw_inputs = primals_and_tangents[: len(_args)] | ||
| example_grads = primals_and_tangents[len(_args) :] | ||
|
|
||
| def run_fwd(*fw_inputs): | ||
| outs = fw_func(*fw_inputs) | ||
| if not isinstance(outs, (list, tuple)): | ||
| outs = (outs,) | ||
| masks = [o.requires_grad for o in outs] | ||
| return (outs, masks) | ||
|
|
||
| joint = create_joint(run_fwd, aot_config=dummy_aot_config) | ||
| optional_grads = [] | ||
| for example_grad in example_grads: | ||
| if example_grad.requires_grad: | ||
| optional_grads.append(example_grad) | ||
| _, grads = joint(fw_inputs, optional_grads) | ||
|
Contributor: i was expecting joint to return both forward outputs and grads. maybe i just misunderstood this?
Member (Author): we could, it doesn't matter if fw_func always gives the same graph |
||
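To make the thread above concrete, here is a hedged standalone sketch of what joint_f computes: gradients for the inputs, given primals and tangents, without returning the forward outputs. It uses plain eager autograd instead of create_joint, and the name reference_joint is invented.

```python
import torch

def reference_joint(fw_func, primals, tangents):
    # Re-run the forward on detached leaves so autograd.grad sees fresh graph inputs.
    primals = [p.detach().requires_grad_(p.requires_grad) for p in primals]
    outs = fw_func(*primals)
    if not isinstance(outs, (list, tuple)):
        outs = (outs,)
    # Only outputs that require grad participate, mirroring the requires_grad masks above.
    grads = torch.autograd.grad(
        outputs=[o for o in outs if o.requires_grad],
        inputs=[p for p in primals if p.requires_grad],
        grad_outputs=[t for o, t in zip(outs, tangents) if o.requires_grad],
        allow_unused=True,
    )
    return grads  # input gradients only; forward outputs are not returned

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
gx, gw = reference_joint(lambda a, b: a * b, [x, w], [torch.ones(3)])
```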
| return grads | ||
|
|
||
| primals_and_tangents = [*fw_inputs, *example_grads] | ||
| joint_graph = make_fx(joint_f)(*primals_and_tangents) | ||
| return None, joint_graph | ||
|
|
||
|
|
||
| class LocalMapAutogradOp(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(ctx, orig_fwd, joint_graph, *args, **kwargs): | ||
| ctx.save_for_backward(*args) | ||
|
|
||
| save_tensors_and_symints_for_backward(ctx, args) | ||
| ctx.joint_graph = joint_graph | ||
|
|
||
| with torch._C._AutoDispatchBelowAutograd(): | ||
| return local_map_hop(orig_fwd, *args, **kwargs) | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, *grads): | ||
| args = saved_tensors_and_symints(ctx) | ||
| grad_ins = ctx.joint_graph(*args, *grads) | ||
| # TODO: hopify to support local_map'd function containing custom autograd.Function | ||
| return None, None, *grad_ins | ||
|
|
||
|
|
||
| @local_map_hop.py_impl(torch._C.DispatchKey.Autograd) | ||
| def autograd_key( | ||
| orig_fwd, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| if "_inline" in kwargs: | ||
| # The solver pass adds an _inline kwarg, which tells this hop to desugar on the next trace | ||
| del kwargs["_inline"] | ||
| return orig_fwd(*args, **kwargs) | ||
|
|
||
| _, joint_graph = create_hop_joint_graph(orig_fwd, *args) | ||
| return LocalMapAutogradOp.apply(orig_fwd, joint_graph, *args, **kwargs) | ||
|
|
||
|
|
||
| @local_map_hop.py_functionalize_impl | ||
| def functional_mode_key(ctx, orig_fwd, *args, **kwargs): | ||
| assert not kwargs | ||
|
|
||
| unwrapped_inputs = ctx.unwrap_tensors(args) | ||
| with ctx.redispatch_to_next(): | ||
| # TODO: dynamo safety checks on orig_fwd | ||
|
Contributor: Yes please. It's a little hard to believe that there isn't an easy utility for this?
Member (Author): There is, speculate_subgraph, but it needs dynamo to apply it, and pass a safe Or were you thinking of torch._subclasses.functional_tensor.*.functionalize
Contributor: We have a util check_input_alias_and_mutation_return_outputs for checking aliasing and mutations of tensor inputs. For others like side-effects, lifting tensor closures/symints as inputs (to make sure dependencies are correctly captured), we would need to turn on dynamo and use speculate_subgraph. |
||
| out = local_map_hop(orig_fwd, *unwrapped_inputs) | ||
| return ctx.wrap_tensors(out) | ||
|
|
||
|
|
||
| @local_map_hop.py_impl(FakeTensorMode) | ||
| def fake_mode_key( | ||
| mode, | ||
| orig_fwd, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| with mode: | ||
| return orig_fwd(*args, **kwargs) | ||
|
|
||
|
|
||
| @local_map_hop.py_impl(ProxyTorchDispatchMode) | ||
| def proxy_mode_key( | ||
| proxy_mode, | ||
| orig_fwd, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| assert ( | ||
| proxy_mode is not None | ||
| ), "Mode should always be enabled for python fallback key" | ||
| assert len(kwargs) == 0 | ||
|
|
||
| example_out = local_map_hop(orig_fwd, *args, **kwargs) | ||
| proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) | ||
|
|
||
| def call_local_map(*another_args, **another_kwargs): | ||
| return functools.partial(local_map_hop, orig_fwd)( | ||
| *another_args, **another_kwargs | ||
| ) | ||
|
|
||
| out_proxy = proxy_mode.tracer.create_proxy( | ||
| "call_function", call_local_map, proxy_args, {} | ||
| ) | ||
| out_proxy.node.meta["custom"] = { | ||
| "dtensor_local_map_kwargs": orig_fwd.local_map_kwargs, | ||
|
Contributor: Storing this in the meta gives me the jeebies, because if you lose the meta there's no way to get this information back, it is load bearing lol
Member (Author): It is sus lol, but meta is already load bearing for export/ac/ though, and we have special cased keys: https://github.com/pytorch/pytorch/blame/255c0545e7eac2ec6d00a41a3fc9d6d8201f8f39/torch/fx/proxy.py#L107-L120. "custom" was added specifically to try to decouple. I think it's probably one of the better places on the fx.Graph, and I haven't really looked at passing this info along outside of the graph obj.
Contributor: Concretely, this sounds like AC is not cacheable by AOTAutograd then. But I guess AC isn't cacheable for other reasons too lol.
Member (Author): In defense of node.meta, this pattern exists because it's quite difficult to propagate information between different compiler modules when tracing. Same story with writing to some global like TracingContext/Virtualized. As long as the types can be cached, can't we snapshot their values at cache lookup time? |
||
| } | ||
| return track_tensor_tree( | ||
| example_out, out_proxy, constant=None, tracer=proxy_mode.tracer | ||
| ) | ||
|
|
||
|
|
||
| # Running HOP in eager with real tensors | ||
| @local_map_hop.py_impl(torch._C.DispatchKey.CPU) | ||
| @local_map_hop.py_impl(torch._C.DispatchKey.CUDA) | ||
| def real_impl( | ||
| orig_fwd, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| return orig_fwd(*args, **kwargs) | ||
|
|
||
|
|
||
| def apply_local_map(*local_map_args, **local_map_kwargs): | ||
| # NOTE: We manually issue the hop, which will not be necessary with a dynamo frontend. | ||
| # 1. Same as local_map, this must be applied to a function, not a method. | ||
| # 2. The local_map'd function must be make_fx traceable. Otherwise, we may | ||
| # inline the wrong graph. In a dynamo frontend, speculate_subgraph will handle this. | ||
| # 3. All inputs to the local_map'd function must be Tensor types. Otherwise, we won't | ||
| # know which tensors to apply _from_fun to. For instance, don't pass nn.Modules to local_map. | ||
| # In a dynamo frontend, tensors will be lifted, which will modify the wrapped function's signature. | ||
|
|
||
| assert local_map_kwargs[ | ||
| "redistribute_inputs" | ||
| ], "Autoparallel should always be allowed to redistribute inputs" | ||
| assert local_map_kwargs["in_grad_placements"] is None, "Not yet implemented" | ||
|
Contributor: my first impression was
Member (Author): I don't think the user should be allowed to specify |
||
| assert local_map_kwargs["device_mesh"] is None, "Must be provided by Autoparallel" | ||
|
|
||
| def decorator(fn): | ||
| @functools.wraps(fn) | ||
| def wrapped(*args, **kwargs): | ||
| def orig_fwd(*runtime_args, **runtime_kwargs): | ||
| # wrap the functools.partial so the hop utils work out of the box | ||
| return local_map( | ||
| fn, | ||
| *local_map_args, | ||
| **local_map_kwargs, | ||
| )(*runtime_args, **runtime_kwargs) | ||
|
|
||
| orig_fwd.local_map_kwargs = local_map_kwargs | ||
| return local_map_hop(orig_fwd, *args, **kwargs) | ||
|
|
||
| return wrapped | ||
|
|
||
| return decorator | ||
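A hedged usage sketch of the decorator above, not taken from the PR: the placements, the single mesh dimension, and the wrapped function are illustrative assumptions, and the placement import path may differ across PyTorch versions.

```python
import torch
from torch.distributed.tensor import Replicate, Shard  # import path may vary by version

@apply_local_map(
    out_placements=(Shard(0),),                   # illustrative: output sharded on mesh dim 0
    in_placements=((Shard(0),), (Replicate(),)),  # illustrative: per-input placements
    device_mesh=None,            # must be None; Autoparallel supplies the mesh
    in_grad_placements=None,     # not yet implemented, must be None
    redistribute_inputs=True,    # required: Autoparallel may redistribute inputs
)
def sharded_region(x, w):
    # Runs on local shards; must be make_fx traceable and take only Tensor inputs.
    return torch.nn.functional.relu(x @ w)
```

At trace time the hop is issued here, so the body of sharded_region appears as a single call_local_map node until the solver marks it for inlining.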
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,7 +90,7 @@ | |
| estimate_strategy_runtime_cost, | ||
| ) | ||
| from .propagation_rules import _create_all_options | ||
| from .utils import get_placement_options | ||
| from .utils import get_local_map_placement_option, get_placement_options | ||
|
|
||
|
|
||
| def _debug_node(node): | ||
|
|
@@ -146,10 +146,31 @@ def build_sharding_metadata(self): | |
| user_kwargs = tree_map_only( | ||
| torch.fx.Node, lambda x: x.meta["val"], node.kwargs | ||
| ) | ||
| strat = get_placement_options( | ||
| self.mesh, node.target, user_strats, user_args, user_kwargs | ||
| ) | ||
| strats[node] = strat | ||
| if local_map_kwargs := node.meta.get("custom", {}).get( | ||
| "dtensor_local_map_kwargs" | ||
|
Contributor: Why are we testing the custom meta as opposed to looking at op.target to identify a local map HOP?
Member (Author): we still need to pass the placement kwargs somehow, op.target alone wouldn't be enough |
||
| ): | ||
| assert "call_local_map" in str(node.target) | ||
| assert not user_kwargs | ||
| strat = get_local_map_placement_option( | ||
| self.mesh, | ||
| user_strats, | ||
| user_args, | ||
| node.meta["val"], | ||
| local_map_kwargs["in_placements"], | ||
| local_map_kwargs["out_placements"], | ||
| ) | ||
|
|
||
| assert not node.kwargs | ||
| node.kwargs = { | ||
| "_inline": True | ||
| } # notify the HOP to desugar in the next trace | ||
|
|
||
| strats[node] = strat | ||
| else: | ||
| strat = get_placement_options( | ||
| self.mesh, node.target, user_strats, user_args, user_kwargs | ||
| ) | ||
| strats[node] = strat | ||
| elif node.op == "output": | ||
| user_strats = tree_map_only( | ||
| torch.fx.Node, lambda x: strats[x], node.args | ||
|
|
||
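For orientation, a hedged standalone sketch (toy graph and illustrative values, not the PR's code) of the node.meta["custom"] round trip connecting the two diffs: the proxy-time hop records the local_map kwargs in the meta, and the solver pass reads them back and flags the node for inlining on the next trace.

```python
import torch.fx as fx

def toy(x):
    return x * 2

gm = fx.symbolic_trace(toy)

# Writer side (mirrors proxy_mode_key): stash kwargs on the node's custom meta.
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta.setdefault("custom", {})["dtensor_local_map_kwargs"] = {
            "in_placements": None,   # illustrative values only
            "out_placements": None,
        }

# Reader side (mirrors build_sharding_metadata): recover kwargs and mark for inlining.
for node in gm.graph.nodes:
    if lm_kwargs := node.meta.get("custom", {}).get("dtensor_local_map_kwargs"):
        # lm_kwargs would drive the placement choice here; we only set the flag.
        node.kwargs = {**node.kwargs, "_inline": True}
```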
Contributor: I think we said no HOPs out of core (@ezyang). Especially if you want to sync this to fbcode.
Member (Author): Okay, I will upstream ASAP.