Add local_map support #47
Conversation
HOP should live in pytorch/pytorch proper. OK to land here for now but better to get to the main repo sooner rather than later (in case someone is doing HOP refactoring).
autoparallel retraces with make_fx, so it seems like it can just evaporate it when that happens.
out = parallel_mod(*x)
out.backward(torch.randn_like(out))

print("All good!")
This example doesn't actually tell us that the local map regions had their sharding respected; you have to explicitly check it in the output! Additionally, the local_map'ed over functions are a bit "too" simple, and are expressible as DTensor operations, IIUC.
Yeah, I didn't write the checks into the code, but from the tlparse, the solver's estimated costs look like the sharding decision made it through. I'm gonna try with context parallel to see if it works already.
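For example, one cheap explicit check would be to assert on the local shard shapes inside the local_map'd region, since the region receives local tensors. A minimal sketch, assuming a [B, H, S, D] query layout and made-up GLOBAL_SEQ / CP_DEGREE constants (none of this is from the PR):

```python
import torch
import torch.nn.functional as F

GLOBAL_SEQ = 4096  # assumed global sequence length, not from this PR
CP_DEGREE = 8      # assumed size of the mesh dim that carries Shard(2)


def checked_attention(q, k, v):
    # If the solver honored Shard(2) on the sequence dim, each rank should only
    # see 1/CP_DEGREE of the global sequence inside the local_map region.
    assert q.shape[2] == GLOBAL_SEQ // CP_DEGREE, f"unexpected local shape {q.shape}"
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)
```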
autoparallel/utils.py (Outdated)
def get_local_map_placement_option(
    mesh, op, specs, user_args, user_kwargs, output_val, in_placements, out_placements
A type signature here would be great.
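Something along these lines, maybe (a hedged sketch; the concrete types are guesses from the call site, not what the PR defines):

```python
from typing import Any, Sequence

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement


def get_local_map_placement_option(
    mesh: DeviceMesh,
    op: Any,                                        # op/node being given a strategy (guess)
    specs: Any,
    user_args: Sequence[Any],
    user_kwargs: dict[str, Any],
    output_val: Any,
    in_placements: Sequence[Sequence[Placement]],   # one placement list per input
    out_placements: Sequence[Sequence[Placement]],  # one placement list per output
) -> Any:                                           # the resulting strategy/placement option
    ...
```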
)
strats[node] = strat
if local_map_kwargs := node.meta.get("custom", {}).get(
    "dtensor_local_map_kwargs"
Why are we testing the custom meta as opposed to looking at op.target to identify a local map HOP?
We still need to pass the placement kwargs somehow; op.target alone wouldn't be enough.
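That is, roughly (a hedged sketch reusing names from this diff; local_map_hop is the HOP object added in this PR):

```python
# Identify the node by its target, but the placements still have to come from
# somewhere: here, the "custom" meta written when the HOP was proxied.
is_local_map_node = node.op == "call_function" and node.target is local_map_hop
local_map_kwargs = node.meta.get("custom", {}).get("dtensor_local_map_kwargs")
if is_local_map_node and local_map_kwargs is not None:
    ...  # build the placement option from local_map_kwargs, as in this PR
```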
assert local_map_kwargs[
    "redistribute_inputs"
], "Autoparallel should always be allowed to redistribute inputs"
assert local_map_kwargs["in_grad_placements"] is None, "Not yet implemented"
My first impression was that in_grad_placements could be supported out of the box, but it seems that's not the case?
I don't think the user should be allowed to specify in_grad_placements, since in autoparallel, the user would have no way of knowing the solver's decision at the time of writing local_map. But there could be an API to actually tell the solver to shard the grads in a certain way. TBD, don't have a clear proposal yet
"""

def __init__(self):
    super().__init__("local_map_hop")
nit: maybe this is hard, but it would be nice if some info about orig_fwd (e.g. name/line) were included in the name of the hop, so that the graph doesn't just show "hop1, hop2, hop3" with no info about what's inside
having stack traces in the autoparallel graphs for all nodes would address this issue too
with suspend_functionalization(), disable_functional_mode():
    with disable_proxy_modes_tracing():

def _from_fun(t):
more descriptive name? (maybe _empty_like() or something)
Oh, I totally missed that "fun" means functional. Although I still don't know what that means.
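FWIW, for plain tensors a renamed helper would boil down to something like this (a hedged sketch; presumably the real _from_fun also handles functional-tensor unwrapping, hence the name, which this ignores):

```python
import torch


def _empty_like(t):
    # Materialize a fresh tensor with the same metadata so the subgraph can be
    # traced outside functionalization; non-tensors (ints/SymInts) pass through.
    if isinstance(t, torch.Tensor):
        return torch.empty_strided(
            t.size(),
            t.stride(),
            dtype=t.dtype,
            device=t.device,
            requires_grad=t.requires_grad,
        )
    return t
```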
isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs
)

# redundant? we already _from_fun'd the inputs
Is the concern that user code could allocate new tensors? If they did, they would already be fake because of the mode, so I'm not sure why we can't return them directly.
I'm not sure why other HOPs do this. You shouldn't need to unwrap the outputs if the inputs were already unwrapped and the fw_func doesn't introduce new functional tensors.
Like you said, allocating new tensors should just produce fakes, not functional-wrapped tensors. The only thing I can think of is if fw_inputs weren't properly unwrapped.
for example_grad in example_grads:
    if example_grad.requires_grad:
        optional_grads.append(example_grad)
_, grads = joint(fw_inputs, optional_grads)
I was expecting joint to return both forward outputs and grads. Maybe I just misunderstood this?
We could; it doesn't matter if fw_func always gives the same graph.
autoparallel/local_map_hop.py (Outdated)
save_tensors_and_symints_for_backward(ctx, args)
ctx.joint_graph = joint_graph

with torch._C._AutoDispatchBelowAutograd():  # why
what is this for?
IIUC we need to do this because issuing ops will always go through all their keys by default. Dispatching below autograd hits the functional key next, which enters the redispatch_to_next context, and that applies to all keys afterwards.
I'm not sure why we need to explicitly use this API here instead of entering the redispatch_to_next context already; I'm just following some existing HOPs' code...
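For reference, the shape of the pattern in existing HOPs is roughly this (a hedged, simplified sketch: local_map_hop is the HOP from this PR, and the backward wiring is illustrative only, not the PR's actual code):

```python
import torch


class _LocalMapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, joint_graph, *args):
        ctx.joint_graph = joint_graph
        ctx.save_for_backward(*[a for a in args if isinstance(a, torch.Tensor)])
        # Re-issue the op below the Autograd key so dispatch continues to the
        # next key (e.g. functionalization) instead of re-entering this class.
        with torch._C._AutoDispatchBelowAutograd():
            return local_map_hop(*args)

    @staticmethod
    def backward(ctx, *grads):
        # Illustrative: run the saved joint graph on (saved inputs, grads).
        _, input_grads = ctx.joint_graph(ctx.saved_tensors, grads)
        return (None, *input_grads)
```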
unwrapped_inputs = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
    # TODO: dynamo safety checks on orig_fwd
Yes please. It's a little hard to believe that there isn't an easy utility for this?
There is: speculate_subgraph, but it needs Dynamo to apply it and to pass a safe orig_fwd_graph instead of orig_fwd to the hop.
Or were you thinking of torch._subclasses.functional_tensor.*.functionalize?
We have a util, check_input_alias_and_mutation_return_outputs, for checking aliasing and mutations of tensor inputs. For other things, like side effects and lifting tensor closures/symints as inputs (to make sure dependencies are correctly captured), we would need to turn on Dynamo and use speculate_subgraph.
    "call_function", call_local_map, proxy_args, {}
)
out_proxy.node.meta["custom"] = {
    "dtensor_local_map_kwargs": orig_fwd.local_map_kwargs,
Storing this in the meta gives me the jeebies, because if you lose the meta there's no way to get this information back; it is load bearing lol
It is sus lol, but meta is already load bearing for export/AC, and we have special-cased keys: https://github.com/pytorch/pytorch/blame/255c0545e7eac2ec6d00a41a3fc9d6d8201f8f39/torch/fx/proxy.py#L107-L120. "custom" was added specifically to try to decouple. I think it's probably one of the better places on the fx.Graph, and I haven't really looked at passing this info along outside of the graph obj.
Concretely, this sounds like AC is not cacheable by AOTAutograd then. But I guess AC isn't cacheable for other reasons too lol.
In defense of node.meta, this pattern exists because it's quite difficult to propagate information between different compiler modules when tracing. Same story with writing to some global like TracingContext/Virtualized. As long as the types can be cached, can't we snapshot their values at cache lookup time?
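For anyone following along, the round trip through node.meta is roughly this (a condensed sketch based on the two hunks quoted above):

```python
# Producer side, while proxying the HOP into the traced graph:
out_proxy.node.meta["custom"] = {
    "dtensor_local_map_kwargs": orig_fwd.local_map_kwargs,
}

# Consumer side, while building sharding strategies for each node:
local_map_kwargs = node.meta.get("custom", {}).get("dtensor_local_map_kwargs")
```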
local_map_hop = LocalMapAOTExportModule()


def create_hop_joint_graph(
Is this different than create_joint?
ezyang left a comment
I don't have any problem with this going in as an unblock. I think it probably needs more work before going in pytorch core.
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree


class LocalMapAOTExportModule(HigherOrderOperator):
I think we said no HOPs out of core (@ezyang). Especially if you want to sync this to fbcode.
Okay, I will upstream ASAP
examples/example_local_map.py (Outdated)
[Replicate(), Replicate(), Shard(2)],
[Replicate(), Replicate(), Shard(2)],
[Replicate(), Replicate(), Shard(2)],
Can we instead make the placement for the query be S(0), S(1), S(2), and for the keys and values S(0), S(1), R? This way we are closer to actually enforcing CP, as we will force an all-gather of the keys and values, and things should work properly (as long as is_causal=False).
added
examples/example_local_map.py (Outdated)
@apply_local_map(
    out_placements=[
        [Replicate(), Replicate(), Shard(2)],
If we make the change I mentioned before, the output placement would be S(0), S(1), S(2).
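Concretely, the suggested placements would look something like this (a hedged sketch: out_placements and redistribute_inputs follow the example in this PR, while the in_placements kwarg name, the [batch, heads, seq, head_dim] layout assumption, and the function body are my guesses):

```python
import torch.nn.functional as F
from torch.distributed.tensor import Replicate, Shard


@apply_local_map(  # apply_local_map is the API added in this PR
    out_placements=[
        [Shard(0), Shard(1), Shard(2)],      # output follows the query's sharding
    ],
    in_placements=(
        [Shard(0), Shard(1), Shard(2)],      # query: sharded on batch/heads/sequence
        [Shard(0), Shard(1), Replicate()],   # key: replicated on the CP-like mesh dim
        [Shard(0), Shard(1), Replicate()],   # value: replicated on the CP-like mesh dim
    ),
    redistribute_inputs=True,
)
def cp_like_attention(q, k, v):
    # With k/v all-gathered, this stays correct as long as is_causal=False.
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)
```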
This PR adds a HOP working with the aot_module_export frontend of autoparallel. The lack of a Dynamo frontend imposes some limitations, documented on the apply_local_map API, which will also no longer be needed once we move to Dynamo.

At a high level:
- The apply_local_map API wraps local_map'd functions into local_map_hop.
- local_map_hop implements dispatch keys to prevent the compiler from inlining into it. It instead inserts a call_function node into the tracer's graph, with node metadata containing the local_map placement information.
- Autoparallel reads the local_map placement information and allows the node to be inlined in the next trace by adding a kwarg: _inline. This special handling is actually desired: we need the solver to consider the placements mandated by local_map in its global cost optimization.
- In the next trace, local_map_hop is inlined into the graph with tensors that should already respect the local_map placements, so local_map becomes a no-op. I considered having the HOP exclude the local_map and only wrap the inner function, but I opted out because it would require some special handling of the HOP, and it would only work for Autoparallel.

Test running: PYTHONPATH=. python examples/example_local_map.py

Sharding and cost estimations look reasonable: P1881707012.
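To make the two-phase flow above concrete, a hedged sketch (the _inline kwarg name comes from the description; the surrounding scaffolding and variable names are illustrative):

```python
# Phase 1: during the aot_module_export trace, the HOP stays opaque. What lands
# in the graph is a call_function node whose meta carries the placements:
#   node.target is local_map_hop
#   node.meta["custom"]["dtensor_local_map_kwargs"]  # placements, redistribute_inputs, ...

# Phase 2: after the solver has fixed a sharding, the graph is retraced with
# make_fx and the HOP is asked to inline its body. The tensors it sees should
# already respect the local_map placements, so local_map itself is a no-op:
out = local_map_hop(*unwrapped_inputs, _inline=True)
```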