Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions rfcs/20180821-differentiable-functional-while.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
:---------------|:-----------------------------------------------------|
| **Author** | Saurabh Saxena (Google) |
| **Sponsor** | Skye Wanderman-Milne (Google) |
| **Updated** | 2018-08-23 |
| **Updated** | 2018-03-01 |
| **Implementation** | [while_v2.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/while_v2.py) |


## Objective
Expand All @@ -28,29 +29,6 @@ We recently added a differentiable version of the [functional If/cond op](https:
## Design Proposal


### Accumulating intermediates


#### Stack vs TensorArray vs TensorList

The current implementation uses [Stacks](https://github.com/tensorflow/tensorflow/blob/51100a8de57ef53e36a8a9f5a9829cbd33fbed04/tensorflow/python/ops/control_flow_ops.py#L1002) for accumulating intermediate values from the forward pass that may be needed for gradient computation. This implementation will use [TensorLists](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/list_ops.cc)(TL) instead which, unlike Stack and TensorArray, do not have a mutable internal state making them easy to differentiate.


#### Algorithm

For each intermediate tensor of the while loop function body that may be needed for gradient computation, we create an empty TensorList and add it to the list of loop_vars. We then push the intermediate values to the TL using the [TensorListPushBack](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/list_ops.cc#L40) op. Note that this way we may be accumulating more tensors than are actually needed for gradient computation. It is even possible that the graph is just used for inference and hence we do not need the accumulators at all! We rely on the [C++ optimization pass](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/model_pruner.cc) that happens after the While op is lowered to remove all such superfluous accumulators. So adding extra accumulators will not have any performance or memory overhead at runtime.

To facilitate use-cases where lowering is not desired we can perform a few optimizations to the functional form of the While op:

* Expose only those intermediate values that are required by the backward pass by building the gradient graph in the forward pass.
* This will increase graph building time.
* Do not accumulate Const nodes. We can lift these outside the while loop.
* Do not accumulate loop vars that are passed-through unchanged.
* Rewrite the forward pass to add accumulators when gradients are requested.
* This will require creating a new While op and new FunctionDefs for the loop condition and body.
* Since we cannot remove nodes from the Graph there will be unused functions and the dangling While op in the GraphDef. These will however be pruned out at runtime and hence will not affect performance or correctness.


### Computing gradients

Excerpt from white paper on [Control Flow in TensorFlow](http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf):
Expand All @@ -65,22 +43,37 @@ Excerpt from white paper on [Control Flow in TensorFlow](http://download.tensorf
>
> Where `N` is the number of iterations that the forward while loop runs, `g_body` is the gradient of the forward loop body, and `g_vars` is the initial values for the loop variables. As we will see later, `g_vars` includes the initial gradients for the loop variables of the forward while loop.

We use the same logic here as well. To get a count of the number of forward iterations we add an integer counter which is initialized to 0 and is incremented in the loop body. Note that we just need the total number of iterations for the gradient pass so we do not need to accumulate the intermediate values of the counter. This counter is always the first output of the While op.
We use the same logic here as well. To compute *g_body* we use the [gradients_impl._GradientsHelper](https://github.com/tensorflow/tensorflow/blob/600caf99897e82cd0db8665acca5e7630ec1a292/tensorflow/python/ops/gradients_impl.py#L599) function which supports computing the gradient of a given [src_graph](https://github.com/tensorflow/tensorflow/blob/600caf99897e82cd0db8665acca5e7630ec1a292/tensorflow/python/ops/gradients_impl.py#L607) in another graph, which in this case is a [FuncGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/func_graph.py#L117). This gradient graph captures references to the intermediate values of the forward graph (the src_graph). Each iteration of *g_body* needs to use the intermediate values from the corresponding iteration of *body*; thus, we accumulate the needed forward values and replace these references with values from the accumulators. See the *Accumulating intermediates* section below.

To get a count of the number of forward iterations we add an integer counter to every while_loop which is initialized to 0 and is incremented in the loop body. Note that we just need the total number of iterations for the gradient pass so we do not need to accumulate the intermediate values of the counter. This counter is always the first output of the While op.

Note that the [While gradient function](https://github.com/tensorflow/tensorflow/blob/ec1effdb69d33c947f30a5155c5cc4104c07a87e/tensorflow/python/ops/while_v2.py#L248) assumes that the first loop output is the
number of loop iterations. The While op generated by the gradient function satisfies the above constraints and hence can be differentiated again to generate higher-order derivatives. However, arbitrary While ops generated outside of this design may be differentiated incorrectly.


### Accumulating intermediates

To compute *g_body* we use the [gradients_impl._GradientsHelper](https://github.com/tensorflow/tensorflow/blob/600caf99897e82cd0db8665acca5e7630ec1a292/tensorflow/python/ops/gradients_impl.py#L599) function which supports computing the gradient of a given [src_graph](https://github.com/tensorflow/tensorflow/blob/600caf99897e82cd0db8665acca5e7630ec1a292/tensorflow/python/ops/gradients_impl.py#L607) in another graph, which in this case is a [_FuncGraph](https://github.com/tensorflow/tensorflow/blob/600caf99897e82cd0db8665acca5e7630ec1a292/tensorflow/python/framework/function.py#L621). This gradient graph captures references to the intermediate values of the forward graph(the src_graph). We replace these references with popped values from the accumulators of the intermediate tensors. Note that these accumulators were already added to the list of loop_vars of the While op and hence were in the list of outputs of the forward While op.

We will register a custom python [gradient function](https://github.com/tensorflow/tensorflow/blob/0440ccfc199cbffc10aae19fde07f0100c823ed9/tensorflow/python/framework/ops.py#L2352) to compute the gradient of a functional While op. This will allow taking the gradient of any functional While op(not only the ones generated by the new while_loop function) which satisfies the following conditions:
#### Stack vs TensorArray vs TensorList

The current implementation uses [Stacks](https://github.com/tensorflow/tensorflow/blob/51100a8de57ef53e36a8a9f5a9829cbd33fbed04/tensorflow/python/ops/control_flow_ops.py#L1002) for accumulating intermediate values from the forward pass that may be needed for gradient computation. The new implementation uses [TensorLists](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/list_ops.cc)(TL) instead which, unlike Stack and TensorArray, do not have a mutable internal state making them easy to differentiate.


1. The first loop output must be the number of loop iterations.
1. Each intermediate tensor of the While body which may be needed during gradient computation must be accumulated in a TensorList. We will check to make sure that the TensorList is indeed unique to the intermediate value.
1. The position of the accumulator in the list of inputs and outputs must be the same.
#### Algorithm

The While op generated by the gradient function satisfies the above constraints and hence can be differentiated again to generate the 2nd order derivative and so on.
After constructing the gradient function of the loop body, we rewrite the loop body function and forward While op to output the needed intermediates. Specifically, for each intermediate tensor of the loop body function that's be needed for the gradient computation, we create an empty TensorList (TL) and add it to the list of forward input loop_vars. In the forward loop body, we then push the intermediate values to the TL using the [TensorListPushBack](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/list_ops.cc#L40) op, and output the final list of accumulated intermediates. The gradient computation then takes the TL as input and pops intermediate values off the TL.

In the case of nested while loops, we will accumulate the intermediate values of inner while loops in nested TensorLists.

To-be-implemented improvements to this algorithm:
* Do not accumulate Const nodes. We can lift these outside the while loop.
* Do not accumulate loop vars that are passed-through unchanged.

Alternatives considered:
* Output every possible intermediate and rely on pruning to clean it up. The original implementation did this, but we changed it to match tf.function.
* Expose only those intermediate values that are required by the backward pass by building the gradient graph in the forward pass.
* This will increase graph building time.


### Memory management

Expand All @@ -94,7 +87,6 @@ In order to get feature parity with the current implementation we will lower the


1. We can perform parallel iterations which are not possible due to the strict mode execution of functions which requires that all inputs to the function must be ready before the function can start executing. We will need to add a `parallel_iterations` attr to the While op.
1. The FunctionLibraryRuntime currently does not allow running multi-device functions.
1. We can perform global grappler optimizations without needing to cross function boundaries. E.g. we can remove accumulators for intermediate values which are not consumed downstream.


Expand Down Expand Up @@ -162,3 +154,7 @@ Accumulators:
## Discussion notes

Please see notes in [tensorflow/community#13](https://github.com/tensorflow/community/pull/13#issuecomment-422591773).

## Update log

2019-03-1: Updated some sections to reflect what was built and added link to implementation.