Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit c4faa81

Browse files
skyeewilderj
authored andcommitted
Post-hoc cond_v2 design doc (#12)
Make early design doc public.
1 parent e81f0f2 commit c4faa81

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

rfcs/20180507-cond-v2.md

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# **"Functional"** **cond design doc**
2+
3+
| Status | Approved |
4+
:-------------- |:---------------------------------------------------- |
5+
| **Author(s)** | Skye Wanderman-Milne ([email protected]) |
6+
| **Created** | 2018-05-07 |
7+
| **Updated** | 2018-08-22 |
8+
9+
## Objective
10+
11+
**Switch tf.cond to emit a single If op.**
12+
13+
We can do tf.while_loop next.
14+
15+
This would make mapping to XLA's control flow constructs easier/possible. In particular, just switching to the If op would be a big win (more work needed to get cond working with XLA than while_loop, which already had a lot of work done), and easier than while loop. It will also making debugging and analysis of cond constructs much simpler, e.g. to implement higher-order derivatives.
16+
17+
Note that cond will still support side-effecting ops (e.g. variable updates).
18+
19+
20+
## Background material
21+
22+
tf.cond API: https://www.tensorflow.org/api_docs/python/tf/cond
23+
24+
If op: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/functional_ops.cc#L104
25+
26+
Overview of current control flow implementation: [Implementation of Control Flow in TensorFlow](http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf)
27+
28+
29+
## Design overview
30+
31+
32+
### Functional tf.cond
33+
34+
The signature of `tf.cond` will stay the same: boolean predicate Tensor, and Python callables for the two branches. The two callables each take no arguments (they instead close over any input tensors), and are required to return the same number and type of tensors.
35+
36+
We need to convert this to the If op signature, which is a boolean predicate, and FunctionDefs for the two branches. The FunctionDefs are required to have the same number and type of inputs and outputs. Luckily, tfe.defun already gives us the machinery to convert the Python callables into FunctionDefs, including converting closures to inputs and adding extra inputs to make the branch signatures match. This is done via an overloaded Graph subclass, [FuncGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/function.py#L191), which gives us the full flexibility of graphs while creating the branch functions.
37+
38+
This conversion results in a single If op representing the `tf.cond`.
39+
40+
41+
### Gradients
42+
43+
The gradient of an If op is another If op. The predicate is the same as the forward op's, and each branch function is the gradient function of the corresponding forward branch.
44+
45+
This requires the gradient branch functions to access intermediate tensors of the forward branch functions. Internal tensors in a function can't be directly accessed, so we need to add the necessary intermediates as outputs to the forward If op (how to do this is discussed in the "Implementation challenges" section).
46+
47+
48+
### Execution
49+
50+
There are two choices for running the resulting If ops:
51+
52+
53+
54+
1. Use the `IfOp` kernel as-is, which runs the functions using `FunctionLibraryRuntime`.
55+
1. "Lower" the If ops to the current `tf.cond` implementation (i.e. `Switch` and `Merge` nodes).
56+
57+
(1) is simpler at a high level, but (2) will avoid some of the implementation challenges below.
58+
59+
The lowering can be implemented as an early (pre-placement) optimization pass, in order for the lowered control flow to be placed, pruned, partitioned, etc. as usual. There are already a few examples of similar passes: ParallelConcatRemovePass, AccumulateNV2RemovePass
60+
61+
**Update**: this is done: [LowerIfOpPass](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/lower_if_op.h)
62+
63+
We don't want to lower If ops that will eventually be consumed by the [XLA encapsulation pass](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/jit/jit_compilation_pass_registration.cc#L35), so the TF-XLA bridge can take advantage of the easy-to-convert functional representation. This can be achieved by setting an attribute on the If op indicating whether it should be lowered, determined by e.g. if the If op is in an `XLAContext`. This may prove useful for other future use cases as well, such as transitioning to using the functional representation in the main TF runtime.
64+
65+
66+
## Implementation challenges
67+
68+
69+
### Exposing intermediate tensors to gradient functions
70+
71+
See the "Gradients" section. We somehow need to add intermediate tensors as outputs to the already-created forward-pass If op and its branch functions. Options:
72+
73+
74+
75+
1. Create a new If op with the required outputs. To prevent running both the original and new ops, we need to rewire the outputs of the original op to use the new op (and ideally modify any existing Tensor objects as well).
76+
1. Modify the existing If op in-place. This involves either modifying or replacing the branch functions, and changing the outputs of the op (tricky, but probably doable).
77+
78+
Note that both of these options require mutating existing graph elements. If the graph has already been run, **this will invalidate any existing Sessions!** Other options:
79+
80+
81+
82+
1. Use placeholders for intermediates during construction, then use a C++ rewrite (Grappler or GraphOptimizationPass) to rewire the graph.
83+
1. Output every possible intermediate.
84+
1. It might already work as-is.
85+
1. Except for ExtendGraph -- solution could be to make C API and Session share Graph*
86+
87+
**Update**: we went with (2) output every possible intermediate
88+
89+
90+
### Making branch function outputs match
91+
92+
After adding the intermediate outputs to the forward If op's branch functions, it's likely the two functions don't have the same output signature anymore. For each new output of each branch, we need to add an extra output tensor to the other branch to mirror it (since the If op requires the two outputs signatures match).
93+
94+
Note that the "mirror" tensors never need to be read. The original output is only consumed by the corresponding gradient function, which is only executed if the original output's branch is taken. Thus, if the mirror tensor is produced, no consumer of it will be run. However, without pruning and/or non-strict execution, the If op must still produce some value for the mirror tensor.
95+
96+
_Solution:_
97+
98+
Introduce a special op to output mirror tensors. This op's shape inference function will claim to output the same shape and type of the mirrored output, but since the tensor isn't actually needed the kernel will produce some small value to avoid producing large unnecessary values. If/when the op doesn't need to produce a value (e.g. via lowering + pruning), the kernel can CHECK or similar.
99+
100+
101+
### Taking the gradient of deserialized If ops
102+
103+
We need a graph representing the branch function of an If op in order to take its gradient. We already have a graph as part of creating the function, but if the graph was loaded from a GraphDef, we no longer have this graph. Options:
104+
105+
106+
107+
1. FunctionDef → Graph method
108+
109+
110+
### Variable initialization
111+
112+
Variables created in the `cond` input callables must be created in the main graph, not in the temporary `FuncGraphs`. Luckily this is already handled by [init_scope](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L5230), which should already be used as necessary to handle creating variables in Defuns, etc.
113+
114+
115+
### Collections
116+
117+
We must support reading and writing to collections in the `cond` input callables.
118+
119+
Reading from collections in eager-mode defuns already works by copying the collections into the `FuncGraphs`, which should presumably work here as well.
120+
121+
For writing, we'll have to forward or copy the values back to the original collections. This is tricky and poorly-defined for Tensor and Operation values, and possibly intractable for data structures containing graph elements (e.g. `WhileContext`). Options:
122+
123+
124+
125+
1. Collections are supposed to go away in TF 2.0
126+
1. Somehow turn Tensors into function outputs
127+
1. Can some tensors/operations be pulled out of the function?
128+
1. Expose "legacy cond" in contrib, eventually deprecate.
129+
130+
**Writing to collections requires more investigation.**
131+
132+
For example, how are people using collections within `cond` branches? How do they avoid dead Tensors?
133+
134+
135+
### Name/device/colocation scope
136+
137+
Similar to reading collections, any graph-wide stacks and other state can be copied into the `FuncGraphs`. New scopes can then be added within the FuncGraph, and the semantics prevent any added state from persisting beyond the input callable.
138+
139+
For colocation, we can possibly use external tensor names as-is, since they'll either be lowered into the main graph or compiled by XLA.
140+
141+
142+
### Control dependencies
143+
144+
If the `tf.cond` call occurs inside a control_dependencies block, the control inputs will be added directly to the resulting If op.
145+
146+
If the `cond` input callables contain control_dependencies blocks referring external tensors, we can create Identity nodes of the external tensors inside the function definition, and then create internal control edges (functions only have data inputs).
147+
148+
_The following concerns are avoided by lowering If ops before execution (see "Execution" section):_
149+
150+
151+
### Devices
152+
153+
Akshay is working on allowing functions to run across multiple devices. My understanding is that it's mostly working, with a few limitations (e.g. all arguments to the function must go through the caller device, colocation with external tensors doesn't work).
154+
155+
156+
### Partial evaluation
157+
158+
TF graphs are pruned before execution, meaning only the subgraph needed to compute the requested output tensors is run (this doesn't work completely for ops in a conditional branch, but some pruning still occurs). This is not currently possible with TF functions; the entire function is run regardless of which outputs are needed. This would need to be supported for parity with the current `cond` implementation.
159+
160+
161+
### Non-strict execution
162+
163+
The current `cond` implementation allows each op in the taken branch to be run as soon as its inputs are ready, even if other ops in the branch aren't ready yet ("non-strict" execution). However, each TF op kernel will only begin running once it's inputs are all ready ("strict" execution), with `Merge` nodes being the only exception. If we replace the current `cond` construct with a single op, this will switch `cond` to strict execution. We would need to support non-strict execution of If ops and their branch functions.
164+
165+
166+
## Future work
167+
168+
**tf.while_loop**. This effort will solve most of the problems with switching to a functional While representation (or a recursive function representation?). The remaining challenges are inserting stacks for the gradients, and support parallel iterations.
169+
170+
**C API support.** Ideally other language bindings support conditional execution as well. The C API already includes the primitives for other bindings to implement something similar to `tf.cond` that produces an `If` op, but the C API `TF_AddGradients` method would need to support `If` ops in order for other bindings to (easily) allow autodiff of conditionals.

0 commit comments

Comments
 (0)