From 38ca627955e5ccdb7eb59e4a362f57e89e46d32a Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 26 Feb 2019 10:02:53 -0800 Subject: [PATCH 1/4] cond_v2 update --- rfcs/20180507-cond-v2.md | 58 +++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/rfcs/20180507-cond-v2.md b/rfcs/20180507-cond-v2.md index fa06ca61e..35ac22ba9 100644 --- a/rfcs/20180507-cond-v2.md +++ b/rfcs/20180507-cond-v2.md @@ -4,21 +4,26 @@ :-------------- |:---------------------------------------------------- | | **Author(s)** | Skye Wanderman-Milne (skyewm@gmail.com) | | **Created** | 2018-05-07 | -| **Updated** | 2018-08-22 | +| **Updated** | 2019-02-26 | +| **Implementation** | [cond_v2.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/cond_v2.py) | ## Objective **Switch tf.cond to emit a single If op.** -We can do tf.while_loop next. - -This would make mapping to XLA's control flow constructs easier/possible. In particular, just switching to the If op would be a big win (more work needed to get cond working with XLA than while_loop, which already had a lot of work done), and easier than while loop. It will also making debugging and analysis of cond constructs much simpler, e.g. to implement higher-order derivatives. +Benefits: +* Second-order derivatives +* Better XLA/TPU integration +* Better error messages +* Fewer bugs Note that cond will still support side-effecting ops (e.g. variable updates). ## Background material +Related tf.while_loop RFC: https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md + tf.cond API: https://www.tensorflow.org/api_docs/python/tf/cond If op: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/functional_ops.cc#L104 @@ -33,7 +38,7 @@ Overview of current control flow implementation: [Implementation of Control Flow The signature of `tf.cond` will stay the same: boolean predicate Tensor, and Python callables for the two branches. The two callables each take no arguments (they instead close over any input tensors), and are required to return the same number and type of tensors. -We need to convert this to the If op signature, which is a boolean predicate, and FunctionDefs for the two branches. The FunctionDefs are required to have the same number and type of inputs and outputs. Luckily, tfe.defun already gives us the machinery to convert the Python callables into FunctionDefs, including converting closures to inputs and adding extra inputs to make the branch signatures match. This is done via an overloaded Graph subclass, [FuncGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/function.py#L191), which gives us the full flexibility of graphs while creating the branch functions. +We need to convert this to the If op signature, which is a boolean predicate, and FunctionDefs for the two branches. The FunctionDefs are required to have the same number and type of inputs and outputs. Luckily, tf.function already gives us the machinery to convert the Python callables into FunctionDefs, including converting closures to inputs and adding extra inputs to make the branch signatures match. This is done via an overloaded Graph subclass, [FuncGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/func_graph.py#L117), which gives us the full flexibility of graphs while creating the branch functions. This conversion results in a single If op representing the `tf.cond`. @@ -70,21 +75,17 @@ We don't want to lower If ops that will eventually be consumed by the [XLA encap See the "Gradients" section. We somehow need to add intermediate tensors as outputs to the already-created forward-pass If op and its branch functions. Options: +Solution:\ +Modify the existing If op in-place. We do this by replacing the branch functions, and changing the outputs of the op (tricky but doable). +Since this is mutating an existing graph element, if the graph has already been run by a Session, this theoretically invalidates the session! In practice the Session appears to still be usable though. -1. Create a new If op with the required outputs. To prevent running both the original and new ops, we need to rewire the outputs of the original op to use the new op (and ideally modify any existing Tensor objects as well). -1. Modify the existing If op in-place. This involves either modifying or replacing the branch functions, and changing the outputs of the op (tricky, but probably doable). - -Note that both of these options require mutating existing graph elements. If the graph has already been run, **this will invalidate any existing Sessions!** Other options: - - - -1. Use placeholders for intermediates during construction, then use a C++ rewrite (Grappler or GraphOptimizationPass) to rewire the graph. -1. Output every possible intermediate. -1. It might already work as-is. - 1. Except for ExtendGraph -- solution could be to make C API and Session share Graph* +This is the same method that tf.function and [while_v2](https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md) use for intermediate gradients, making all of them compose nicely. -**Update**: we went with (2) output every possible intermediate +Alternatives considered:\ +1. Output every possible intermediate and rely on pruning to clean it up. The original implementation did this, but we changed it to match tf.function. +1. Create a new If op with the required outputs. To prevent running both the original and new ops, we need to rewire the outputs of the original op to use the new op (and ideally modify any existing Tensor objects as well). This also requires a Session-invalidating graph mutation. +1. Use placeholders for intermediates during construction, then use a C++ rewrite (Grappler or GraphOptimizationPass) to rewire the graph. This is effectively creating an alternative representation of the dataflow graph, which is undesireable (e.g. all graph traversal code would need to know about these special placeholders). ### Making branch function outputs match @@ -93,26 +94,34 @@ After adding the intermediate outputs to the forward If op's branch functions, i Note that the "mirror" tensors never need to be read. The original output is only consumed by the corresponding gradient function, which is only executed if the original output's branch is taken. Thus, if the mirror tensor is produced, no consumer of it will be run. However, without pruning and/or non-strict execution, the If op must still produce some value for the mirror tensor. -_Solution:_ +Solution:\ +Wrap all intermediate outputs in optionals. Optionals are like maybe or [optional types](https://en.wikipedia.org/wiki/Option_type) in TensorFlow. They are variant-type tensors that may or may not contain a value tensor, which are created and introspected by [these ops](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L629). -Introduce a special op to output mirror tensors. This op's shape inference function will claim to output the same shape and type of the mirrored output, but since the tensor isn't actually needed the kernel will produce some small value to avoid producing large unnecessary values. If/when the op doesn't need to produce a value (e.g. via lowering + pruning), the kernel can CHECK or similar. +In the branch where the intermediate is actually produced, the intermediate tensor is wrapped in an optional via the [`OptionalFromValue` op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L629), meaning the output optional will contain the intermediate if that branch is taken. In the other branch, the "mirror" tensor is produced via the [`OptionalNone` op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/dataset_ops.cc#L635), meaning the output optional will have no value if the other branch is taken. +Each gradient branch then only unwraps the optionals from its corresponding forward branch to pass to the gradient computation. -### Taking the gradient of deserialized If ops +Alternatives considered:\ +1. Output dead tensors as the "mirror" tensors, similar to the current tf.cond implementation. This requires changes to the executor to make it not mark If ops, While ops, functions ops, and possibly other special cases as dead if they have dead inputs, and prevents us from someday simplifying the excutor by removing the dead tensor logic. +1. Introduce a special op to output mirror tensors. This op's shape inference function will claim to output the same shape and type of the mirrored output, but since the tensor isn't actually needed the kernel will produce some small value to avoid producing large unnecessary values. -We need a graph representing the branch function of an If op in order to take its gradient. We already have a graph as part of creating the function, but if the graph was loaded from a GraphDef, we no longer have this graph. Options: +### Taking the gradient of deserialized If ops +We need a graph representing the branch function of an If op in order to take its gradient. We already have a graph as part of creating the function, but if the graph was loaded from a GraphDef, we no longer have this graph. Options: -1. FunctionDef → Graph method +Solution:\ +[function_def_to_graph method](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function_def_to_graph.py) ### Variable initialization +**Needs update** Variables created in the `cond` input callables must be created in the main graph, not in the temporary `FuncGraphs`. Luckily this is already handled by [init_scope](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L5230), which should already be used as necessary to handle creating variables in Defuns, etc. ### Collections +**Needs update** We must support reading and writing to collections in the `cond` input callables. @@ -133,6 +142,7 @@ For example, how are people using collections within `cond` branches? How do the ### Name/device/colocation scope +**Needs update** Similar to reading collections, any graph-wide stacks and other state can be copied into the `FuncGraphs`. New scopes can then be added within the FuncGraph, and the semantics prevent any added state from persisting beyond the input callable. @@ -140,6 +150,7 @@ For colocation, we can possibly use external tensor names as-is, since they'll e ### Control dependencies +**Needs update** If the `tf.cond` call occurs inside a control_dependencies block, the control inputs will be added directly to the resulting If op. @@ -149,6 +160,7 @@ _The following concerns are avoided by lowering If ops before execution (see "Ex ### Devices +**Needs update** Akshay is working on allowing functions to run across multiple devices. My understanding is that it's mostly working, with a few limitations (e.g. all arguments to the function must go through the caller device, colocation with external tensors doesn't work). @@ -165,6 +177,4 @@ The current `cond` implementation allows each op in the taken branch to be run a ## Future work -**tf.while_loop**. This effort will solve most of the problems with switching to a functional While representation (or a recursive function representation?). The remaining challenges are inserting stacks for the gradients, and support parallel iterations. - **C API support.** Ideally other language bindings support conditional execution as well. The C API already includes the primitives for other bindings to implement something similar to `tf.cond` that produces an `If` op, but the C API `TF_AddGradients` method would need to support `If` ops in order for other bindings to (easily) allow autodiff of conditionals. From f6eb4ed3c8a4a95e2abce51e1899ea6e12721ab9 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 26 Feb 2019 13:21:51 -0800 Subject: [PATCH 2/4] minor fix --- rfcs/20180507-cond-v2.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rfcs/20180507-cond-v2.md b/rfcs/20180507-cond-v2.md index 35ac22ba9..9d4010f64 100644 --- a/rfcs/20180507-cond-v2.md +++ b/rfcs/20180507-cond-v2.md @@ -115,13 +115,13 @@ Solution:\ ### Variable initialization -**Needs update** +**TODO: this section needs updating** Variables created in the `cond` input callables must be created in the main graph, not in the temporary `FuncGraphs`. Luckily this is already handled by [init_scope](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L5230), which should already be used as necessary to handle creating variables in Defuns, etc. ### Collections -**Needs update** +**TODO: this section needs updating** We must support reading and writing to collections in the `cond` input callables. @@ -142,7 +142,7 @@ For example, how are people using collections within `cond` branches? How do the ### Name/device/colocation scope -**Needs update** +**TODO: this section needs updating** Similar to reading collections, any graph-wide stacks and other state can be copied into the `FuncGraphs`. New scopes can then be added within the FuncGraph, and the semantics prevent any added state from persisting beyond the input callable. @@ -150,7 +150,7 @@ For colocation, we can possibly use external tensor names as-is, since they'll e ### Control dependencies -**Needs update** +**TODO: this section needs updating** If the `tf.cond` call occurs inside a control_dependencies block, the control inputs will be added directly to the resulting If op. @@ -160,7 +160,7 @@ _The following concerns are avoided by lowering If ops before execution (see "Ex ### Devices -**Needs update** +**TODO: this section needs updating** Akshay is working on allowing functions to run across multiple devices. My understanding is that it's mostly working, with a few limitations (e.g. all arguments to the function must go through the caller device, colocation with external tensors doesn't work). From 1efbf1e3277a36d2b8eeb4daa81e346e26704c20 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 27 Feb 2019 11:13:03 -0800 Subject: [PATCH 3/4] update log --- rfcs/20180507-cond-v2.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rfcs/20180507-cond-v2.md b/rfcs/20180507-cond-v2.md index 9d4010f64..2ae50cf13 100644 --- a/rfcs/20180507-cond-v2.md +++ b/rfcs/20180507-cond-v2.md @@ -178,3 +178,7 @@ The current `cond` implementation allows each op in the taken branch to be run a ## Future work **C API support.** Ideally other language bindings support conditional execution as well. The C API already includes the primitives for other bindings to implement something similar to `tf.cond` that produces an `If` op, but the C API `TF_AddGradients` method would need to support `If` ops in order for other bindings to (easily) allow autodiff of conditionals. + +## Update log + +2019-02-26: Updated some sections to reflect what was built and added link to implementation. Marked other sections as still needing update; many of these concerns are common to cond_v2, while_v2, and functions, so we may wanna include these as part of a function design doc. From aeaafce396d5f4cce97627c55d77a4947006fa40 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 27 Feb 2019 12:13:49 -0800 Subject: [PATCH 4/4] srbs --- rfcs/20180507-cond-v2.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rfcs/20180507-cond-v2.md b/rfcs/20180507-cond-v2.md index 2ae50cf13..cd62c3fb7 100644 --- a/rfcs/20180507-cond-v2.md +++ b/rfcs/20180507-cond-v2.md @@ -12,7 +12,7 @@ **Switch tf.cond to emit a single If op.** Benefits: -* Second-order derivatives +* Higher-order derivatives * Better XLA/TPU integration * Better error messages * Fewer bugs @@ -82,7 +82,7 @@ Since this is mutating an existing graph element, if the graph has already been This is the same method that tf.function and [while_v2](https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md) use for intermediate gradients, making all of them compose nicely. -Alternatives considered:\ +Alternatives considered: 1. Output every possible intermediate and rely on pruning to clean it up. The original implementation did this, but we changed it to match tf.function. 1. Create a new If op with the required outputs. To prevent running both the original and new ops, we need to rewire the outputs of the original op to use the new op (and ideally modify any existing Tensor objects as well). This also requires a Session-invalidating graph mutation. 1. Use placeholders for intermediates during construction, then use a C++ rewrite (Grappler or GraphOptimizationPass) to rewire the graph. This is effectively creating an alternative representation of the dataflow graph, which is undesireable (e.g. all graph traversal code would need to know about these special placeholders). @@ -101,7 +101,7 @@ In the branch where the intermediate is actually produced, the intermediate tens Each gradient branch then only unwraps the optionals from its corresponding forward branch to pass to the gradient computation. -Alternatives considered:\ +Alternatives considered: 1. Output dead tensors as the "mirror" tensors, similar to the current tf.cond implementation. This requires changes to the executor to make it not mark If ops, While ops, functions ops, and possibly other special cases as dead if they have dead inputs, and prevents us from someday simplifying the excutor by removing the dead tensor logic. 1. Introduce a special op to output mirror tensors. This op's shape inference function will claim to output the same shape and type of the mirrored output, but since the tensor isn't actually needed the kernel will produce some small value to avoid producing large unnecessary values.