Description
Since #66873 was merged, the compiler can differentiate through functions with multiple results, such as functions with a differentiable inout parameter that also return a result.
Unfortunately, we cannot yet ask for the pullback of these functions directly, because `valueWithPullback` has no overloads that accept inout parameters.
A potential function signature (for arity 1) would be:
```swift
@inlinable
public func valueWithPullback<T, R>(
  at x: inout T, of f: @differentiable(reverse) (inout T) -> R
) -> (value: R, pullback: (R.TangentVector, inout T.TangentVector) -> Void) {
  return Builtin.applyDerivative_vjp(f, x) // Currently missing Builtin
}
```

Currently we can work around this missing feature by copying the parameter inside a non-inout wrapper function:
```swift
import _Differentiation

@differentiable(reverse)
func square(x: inout Double) { // we can't directly call valueWithPullback on this function
  x = x * x
}

@differentiable(reverse)
func nonInoutSquare(x: Double) -> Double {
  var x = x     // extra copy, only needed because of the wrapper
  square(x: &x)
  return x
}

let result = valueWithPullback(at: 5.0, of: nonInoutSquare)
```

This largely defeats the point, both for expressivity and for performance: the wrapper has to make additional copies that would be avoided by passing the parameter directly as inout.
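To make the requested semantics concrete, here is a hand-written sketch of what a value-and-pullback pair for a function with both an inout parameter and a result could look like, in the shape of the proposed signature `(R.TangentVector, inout T.TangentVector) -> Void`. It does not use `_Differentiation` at all; `squareAndDouble` and `squareAndDoubleVJP` are hypothetical names for illustration only, and the tangent convention (on entry `dx` holds the tangent of the mutated value, on exit the tangent of the original input) is my assumption about how inout tangents would flow.

```swift
/// Hypothetical example: returns x_in * x_in and doubles x in place.
func squareAndDouble(x: inout Double) -> Double {
    let result = x * x
    x = 2 * x
    return result
}

/// Hypothetical hand-written VJP. Mutates `x` exactly like `squareAndDouble`
/// (no extra copy of the argument is returned) and hands back a pullback
/// with the shape (R.TangentVector, inout T.TangentVector) -> Void.
func squareAndDoubleVJP(
    x: inout Double
) -> (value: Double, pullback: (Double, inout Double) -> Void) {
    let original = x                    // pre-mutation value, needed by the pullback
    let value = squareAndDouble(x: &x)  // mutate in place
    return (value: value, pullback: { dResult, dx in
        // Chain rule: d(result)/d(x_in) = 2 * x_in and d(x_out)/d(x_in) = 2.
        // On entry `dx` is the tangent of x_out; on exit it is the tangent of x_in.
        dx = dResult * 2 * original + dx * 2
    })
}

var x = 5.0
let (value, pullback) = squareAndDoubleVJP(x: &x)
var dx = 0.0
pullback(1.0, &dx)   // gradient of the result with respect to the input
print(value, x, dx)  // 25.0 10.0 10.0
```

The point of the sketch is that the pullback only needs to capture the original value; the argument itself is updated in place, which is what the copying workaround cannot express.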
Potential issue:
There are currently three `valueWithPullback` overloads, one for each arity from 1 to 3. Because of the underlying Builtins, we unfortunately can't collapse these into one using parameter packs (as far as I can tell). Adding overloads with inout parameters would greatly increase the number of overloads, since every combination of parameters being "normal" or inout, and of functions having a differentiable result or not, needs its own overload. Inout parameters also don't lend themselves to parameter packs at this time (as far as I know).
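For a rough sense of scale, this snippet counts overload shapes under simple assumptions of my own (not an exhaustive design): each parameter is either by-value or inout, the function either returns a differentiable result or returns Void, and the Void-returning shape with no inout parameter is skipped because there is nothing to differentiate. It only counts `valueWithPullback` itself; `OverloadShape` and `overloadShapes` are hypothetical helpers.

```swift
/// One candidate overload shape: which parameters are inout, and
/// whether the function returns a differentiable result.
struct OverloadShape {
    let inoutMask: [Bool]  // inoutMask[i] == true means parameter i is inout
    let hasResult: Bool
}

/// Enumerates every shape for a given arity, skipping the one with
/// no inout parameter and no result (nothing to differentiate).
func overloadShapes(arity: Int) -> [OverloadShape] {
    var shapes: [OverloadShape] = []
    for mask in 0..<(1 << arity) {
        let inoutMask = (0..<arity).map { (mask & (1 << $0)) != 0 }
        for hasResult in [true, false] {
            if !hasResult && !inoutMask.contains(true) { continue }
            shapes.append(OverloadShape(inoutMask: inoutMask, hasResult: hasResult))
        }
    }
    return shapes
}

let counts = (1...3).map { overloadShapes(arity: $0).count }
print(counts, counts.reduce(0, +))  // [3, 7, 15] 25
```

Under these assumptions, each arity n needs 2^(n+1) - 1 shapes, so the three existing overloads would grow to about 25.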
Do people see any other potential roadblocks for this feature?