:-------------- |:---------------------------------------------------- |
| **Author(s)** | Skye Wanderman-Milne ([email protected]) |
| **Created** | 2018-05-07 |
| **Updated** | 2019-02-26 |
| **Implementation** | [cond_v2.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/cond_v2.py) |

## Objective

**Switch tf.cond to emit a single If op.**

We can do tf.while_loop next.

Benefits:
* Higher-order derivatives
* Better XLA/TPU integration
* Better error messages
* Fewer bugs

Note that cond will still support side-effecting ops (e.g. variable updates).


## Background material

Related tf.while_loop RFC: https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md

tf.cond API: https://www.tensorflow.org/api_docs/python/tf/cond

If op: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/functional_ops.cc#L104
Overview of current control flow implementation: the _Implementation of Control Flow in TensorFlow_ white paper

The signature of `tf.cond` will stay the same: boolean predicate Tensor, and Python callables for the two branches. The two callables each take no arguments (they instead close over any input tensors), and are required to return the same number and type of tensors.
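For reference, a minimal sketch of the unchanged public API (the values and branch bodies here are arbitrary):

```python
import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(5.0)

# Both branch callables take no arguments and close over x and y;
# they must return the same number and types of tensors.
result = tf.cond(x < y,
                 lambda: x + y,   # executed when the predicate is True
                 lambda: x * y)   # executed when the predicate is False
```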

We need to convert this to the If op signature, which is a boolean predicate, and FunctionDefs for the two branches. The FunctionDefs are required to have the same number and type of inputs and outputs. Luckily, tf.function already gives us the machinery to convert the Python callables into FunctionDefs, including converting closures to inputs and adding extra inputs to make the branch signatures match. This is done via an overloaded Graph subclass, [FuncGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/func_graph.py#L117), which gives us the full flexibility of graphs while creating the branch functions.

This conversion results in a single If op representing the `tf.cond`.
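As an illustrative sanity check (not part of the design itself), one can trace a function containing a `tf.cond` and inspect the resulting graph; the exact op type (`If` vs. `StatelessIf`) depends on whether the branches contain stateful ops:

```python
import tensorflow as tf

@tf.function
def f(pred, x):
  return tf.cond(pred, lambda: x + 1.0, lambda: x * 2.0)

graph = f.get_concrete_function(
    tf.TensorSpec([], tf.bool), tf.TensorSpec([], tf.float32)).graph

# The cond is represented by a single functional op (If/StatelessIf);
# its branch functions live in the graph's function library.
print([op.type for op in graph.get_operations()])
```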


See the "Gradients" section. We somehow need to add intermediate tensors as outputs to the already-created forward-pass If op and its branch functions. Options:

Solution:\
Modify the existing If op in-place. We do this by replacing the branch functions, and changing the outputs of the op (tricky but doable).

Since this is mutating an existing graph element, if the graph has already been run by a Session, this theoretically invalidates the session! In practice the Session appears to still be usable though.


This is the same method that tf.function and [while_v2](https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md) use for intermediate gradients, making all of them compose nicely.


Alternatives considered:
1. Output every possible intermediate and rely on pruning to clean it up. The original implementation did this, but we changed it to match tf.function.
1. Create a new If op with the required outputs. To prevent running both the original and new ops, we need to rewire the outputs of the original op to use the new op (and ideally modify any existing Tensor objects as well). This also requires a Session-invalidating graph mutation.
1. Use placeholders for intermediates during construction, then use a C++ rewrite (Grappler or GraphOptimizationPass) to rewire the graph. This is effectively creating an alternative representation of the dataflow graph, which is undesirable (e.g. all graph traversal code would need to know about these special placeholders).
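As an end-to-end illustration of what exposing intermediates enables (a sketch using only public APIs; the cubic branch is arbitrary), higher-order derivatives can be taken through a traced cond:

```python
import tensorflow as tf

@tf.function  # trace to a graph so the cond becomes an If/StatelessIf op
def second_derivative(x):
  with tf.GradientTape() as outer_tape:
    outer_tape.watch(x)
    with tf.GradientTape() as inner_tape:
      inner_tape.watch(x)
      y = tf.cond(x > 0.0, lambda: x * x * x, lambda: -x)
    dy_dx = inner_tape.gradient(y, x)   # 3 * x**2 on the taken branch
  return outer_tape.gradient(dy_dx, x)  # 6 * x

print(second_derivative(tf.constant(2.0)))  # expect 12.0
```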


### Making branch function outputs match
After adding the intermediate outputs to the forward If op's branch functions, the two branches no longer have matching outputs (each branch has its own intermediates), so each new output in one branch needs a corresponding "mirror" output in the other branch.

Note that the "mirror" tensors never need to be read. The original output is only consumed by the corresponding gradient function, which is only executed if the original output's branch is taken. Thus, if the mirror tensor is produced, no consumer of it will be run. However, without pruning and/or non-strict execution, the If op must still produce some value for the mirror tensor.

Solution:\
Wrap all intermediate outputs in optionals. Optionals are TensorFlow's analogue of maybe/[option types](https://en.wikipedia.org/wiki/Option_type): variant-type tensors that may or may not contain a value tensor, created and introspected by [these ops](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L629).

In the branch where the intermediate is actually produced, the intermediate tensor is wrapped in an optional via the [`OptionalFromValue` op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L629), meaning the output optional will contain the intermediate if that branch is taken. In the other branch, the "mirror" tensor is produced via the [`OptionalNone` op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L635), meaning the output optional will have no value if the other branch is taken.

Each gradient branch then only unwraps the optionals from its corresponding forward branch to pass to the gradient computation.
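To make the optional semantics concrete, here is a small sketch using the public `tf.experimental.Optional` wrapper around these ops (illustrating the concept, not the cond_v2 internals; the exact public namespace varies across TF versions):

```python
import tensorflow as tf

spec = tf.TensorSpec([], tf.float32)

# Branch that actually produces the intermediate: wrap it (OptionalFromValue).
with_value = tf.experimental.Optional.from_value(tf.constant(42.0))

# The other branch's "mirror" output: an empty optional (OptionalNone).
mirror = tf.experimental.Optional.empty(spec)

print(with_value.has_value())  # True
print(with_value.get_value())  # 42.0
print(mirror.has_value())      # False; calling get_value() here would fail
```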

Alternatives considered:
1. Output dead tensors as the "mirror" tensors, similar to the current tf.cond implementation. This requires changes to the executor to make it not mark If ops, While ops, function ops, and possibly other special cases as dead if they have dead inputs, and prevents us from someday simplifying the executor by removing the dead tensor logic.
1. Introduce a special op to output mirror tensors. This op's shape inference function will claim to output the same shape and type as the mirrored output, but since the tensor isn't actually needed, the kernel will produce some small value to avoid producing large unnecessary values.


### Taking the gradient of deserialized If ops

We need a graph representing the branch function of an If op in order to take its gradient. We already have a graph as part of creating the function, but if the graph was loaded from a GraphDef, we no longer have this graph.

Solution:\
Use the [function_def_to_graph method](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function_def_to_graph.py) to reconstruct a graph from the branch FunctionDef.
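A rough sketch of how this might be used (internal, unstable API; the GraphDef loading and branch-name lookup below are assumptions for illustration):

```python
# Hedged sketch using internal APIs (subject to change): rebuild a graph for
# one branch of a deserialized If op so its gradient can be constructed.
from tensorflow.python.framework import function_def_to_graph as fdg_lib

def branch_graph_from_graph_def(graph_def, branch_func_name):
  # The If op's branch FunctionDefs live in the GraphDef's function library.
  fdef = next(f for f in graph_def.library.function
              if f.signature.name == branch_func_name)
  # Returns a graph whose inputs/outputs mirror the FunctionDef's signature.
  return fdg_lib.function_def_to_graph(fdef)
```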


### Variable initialization
**TODO: this section needs updating**

Variables created in the `cond` input callables must be created in the main graph, not in the temporary `FuncGraphs`. Luckily this is already handled by [init_scope](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L5230), which should already be used as necessary to handle creating variables in Defuns, etc.
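A minimal sketch of the `init_scope` behavior being relied on, using the public `tf.init_scope` (not the cond_v2 code itself):

```python
import tensorflow as tf

@tf.function
def branch_like():
  with tf.init_scope():
    # At trace time this block runs in the outer (eager/main-graph) context,
    # so anything created here escapes the temporary FuncGraph.
    lifted = tf.constant(1.0)
  # `lifted` is captured back into the FuncGraph as an external input.
  return lifted + 1.0

print(branch_like())  # tf.Tensor(2.0, shape=(), dtype=float32)
```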


### Collections
**TODO: this section needs updating**

We must support reading and writing to collections in the `cond` input callables.

For example, how are people using collections within `cond` branches?


### Name/device/colocation scope
**TODO: this section needs updating**

Similar to reading collections, any graph-wide stacks and other state can be copied into the `FuncGraphs`. New scopes can then be added within the FuncGraph, and the semantics prevent any added state from persisting beyond the input callable.

For colocation, we can possibly use external tensor names as-is, since they'll either be lowered into the main graph or compiled by XLA.


### Control dependencies
**TODO: this section needs updating**

If the `tf.cond` call occurs inside a control_dependencies block, the control inputs will be added directly to the resulting If op.
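An illustrative check of this behavior under TF2's control-flow-v2 tracing (a sketch; exact op types and where the control edge lands may vary):

```python
import tensorflow as tf

@tf.function
def f(pred, x):
  print_op = tf.print("before cond")           # stateful op to depend on
  with tf.control_dependencies([print_op]):
    return tf.cond(pred, lambda: x + 1.0, lambda: x * 2.0)

graph = f.get_concrete_function(
    tf.TensorSpec([], tf.bool), tf.TensorSpec([], tf.float32)).graph
if_op = next(op for op in graph.get_operations()
             if op.type in ("If", "StatelessIf"))
# The control input should show up on the If op itself.
print([c.name for c in if_op.control_inputs])
```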

_The following concerns are avoided by lowering If ops before execution (see "Execution")._


### Devices
**TODO: this section needs updating**

Akshay is working on allowing functions to run across multiple devices. My understanding is that it's mostly working, with a few limitations (e.g. all arguments to the function must go through the caller device, colocation with external tensors doesn't work).


## Future work

**tf.while_loop**. This effort will solve most of the problems with switching to a functional While representation (or a recursive function representation?). The remaining challenges are inserting stacks for the gradients, and supporting parallel iterations.

**C API support.** Ideally other language bindings support conditional execution as well. The C API already includes the primitives for other bindings to implement something similar to `tf.cond` that produces an `If` op, but the C API `TF_AddGradients` method would need to support `If` ops in order for other bindings to (easily) allow autodiff of conditionals.

## Update log

2019-02-26: Updated some sections to reflect what was built and added link to implementation. Marked other sections as still needing update; many of these concerns are common to cond_v2, while_v2, and functions, so we may want to include these as part of a function design doc.