From 73c1d7f097b67abce881adbb356658e5bf25a7bd Mon Sep 17 00:00:00 2001
From: Tim Holy
Date: Tue, 23 Apr 2024 16:49:48 -0500
Subject: [PATCH] WIP: `value_jacobian_and_hessian`

This aims to support returning derivatives of order 0-2 for vector-valued
functions: for a function with `m` outputs and `n` inputs,
`value_jacobian_and_hessian` returns the value, the `m × n` Jacobian, and an
`m × n × n` array `H` of stacked Hessians in which `H[i, :, :]` is the
Hessian of output `i`.
---
 Project.toml                                  |  8 +-
 docs/src/implementer_guide.md                 |  3 +-
 docs/src/user_guide.md                        |  1 +
 ext/AbstractDifferentiationStaticArraysExt.jl |  8 ++
 src/AbstractDifferentiation.jl                | 89 ++++++++++++++-----
 test/test_utils.jl                            | 18 +++-
 6 files changed, 101 insertions(+), 26 deletions(-)
 create mode 100644 ext/AbstractDifferentiationStaticArraysExt.jl

diff --git a/Project.toml b/Project.toml
index 9db693d..e1c6d88 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "AbstractDifferentiation"
 uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 authors = ["Mohamed Tarek and contributors"]
-version = "0.6.2"
+version = "0.6.3"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,6 +15,7 @@
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
@@ -23,6 +24,7 @@
 AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore"
 AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences"
 AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"]
 AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"]
+AbstractDifferentiationStaticArraysExt = "StaticArrays"
 AbstractDifferentiationTrackerExt = "Tracker"
 AbstractDifferentiationZygoteExt = "Zygote"
 
@@ -35,6 +37,7 @@ FiniteDifferences = "0.12"
 ForwardDiff = "0.10"
 Requires = "1"
 ReverseDiff = "1"
+StaticArrays = "1"
 Tracker = "0.2"
 Zygote = "0.6"
 julia = "1.6"
@@ -47,9 +50,10 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
+test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "Tracker", "Zygote"]
diff --git a/docs/src/implementer_guide.md b/docs/src/implementer_guide.md
index 4a14e43..fd5ab10 100644
--- a/docs/src/implementer_guide.md
+++ b/docs/src/implementer_guide.md
@@ -1,7 +1,7 @@
 # Implementer guide
 
 !!! warning "Work in progress"
-    
+
     Come back later!
 
 ## The macro `@primitive`
@@ -37,6 +37,7 @@ They are just listed here to help readers figure out the code structure:
 - `value_and_second_derivative` calls `second_derivative`
 - `value_gradient_and_hessian` calls `value_and_jacobian` and `gradient`
 - `value_derivative_and_second_derivative` calls `value_and_derivative` and `second_derivative`
+- `value_jacobian_and_hessian` calls `value_and_jacobian` and `jacobian`
 - `pushforward_function` calls `jacobian`
 - `value_and_pushforward_function` calls `pushforward_function`
 - `pullback_function` calls `value_and_pullback_function`
diff --git a/docs/src/user_guide.md b/docs/src/user_guide.md
index e09768c..fc3077c 100644
--- a/docs/src/user_guide.md
+++ b/docs/src/user_guide.md
@@ -75,6 +75,7 @@ AbstractDifferentiation.value_and_second_derivative
 AbstractDifferentiation.value_and_hessian
 AbstractDifferentiation.value_derivative_and_second_derivative
 AbstractDifferentiation.value_gradient_and_hessian
+AbstractDifferentiation.value_jacobian_and_hessian
 ```
 
 ## Jacobian-vector products
diff --git a/ext/AbstractDifferentiationStaticArraysExt.jl b/ext/AbstractDifferentiationStaticArraysExt.jl
new file mode 100644
index 0000000..f50a7d0
--- /dev/null
+++ b/ext/AbstractDifferentiationStaticArraysExt.jl
@@ -0,0 +1,8 @@
+module AbstractDifferentiationStaticArraysExt
+
+import AbstractDifferentiation as AD
+using StaticArrays
+
+# Normalize axis tuples that mix static and dynamic axes: any combination of
+# `Base.OneTo` and `StaticArrays.SOneTo` is mapped to plain `Base.OneTo`, so
+# the `reshape` in `value_jacobian_and_hessian` receives a homogeneous tuple.
+AD.sameindex(idxs::Union{Base.OneTo, StaticArrays.SOneTo}...) = map(Base.OneTo, idxs)
+
+end
diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl
index 389f4c8..cf335b9 100644
--- a/src/AbstractDifferentiation.jl
+++ b/src/AbstractDifferentiation.jl
@@ -136,7 +136,7 @@ end
     AD.value_and_gradient(ab::AD.AbstractBackend, f, xs...)
 
 Return the tuple `(v, gs)` of the function value `v = f(xs...)` and the gradients `gs = AD.gradient(ab, f, xs...)`.
-    
+
 See also [`AbstractDifferentiation.gradient`](@ref).
 """
 function value_and_gradient(ab::AbstractBackend, f, xs...)
@@ -148,7 +148,7 @@ end
     AD.value_and_jacobian(ab::AD.AbstractBackend, f, xs...)
 
 Return the tuple `(v, Js)` of the function value `v = f(xs...)` and the Jacobians `Js = AD.jacobian(ab, f, xs...)`.
-    
+
 See also [`AbstractDifferentiation.jacobian`](@ref).
 """
 function value_and_jacobian(ab::AbstractBackend, f, xs...)
@@ -173,7 +173,7 @@ end
 
 Return the tuple `(v, H)` of the function value `v = f(x)` and the Hessian `H = AD.hessian(ab, f, x)`.
 
-See also [`AbstractDifferentiation.hessian`](@ref). 
+See also [`AbstractDifferentiation.hessian`](@ref).
 """
 function value_and_hessian(ab::AbstractBackend, f, x)
     if x isa Tuple
@@ -214,7 +214,7 @@ end
 """
     AD.value_gradient_and_hessian(ab::AD.AbstractBackend, f, x)
-    
+
 Return the tuple `(v, g, H)` of the function value `v = f(x)`, the gradient `g = AD.gradient(ab, f, x)`, and the Hessian `H = AD.hessian(ab, f, x)`.
 
 See also [`AbstractDifferentiation.gradient`](@ref) and [`AbstractDifferentiation.hessian`](@ref).
@@ -236,11 +236,38 @@ function value_gradient_and_hessian(ab::AbstractBackend, f, x)
     return value, (grads,), hess
 end
 
+"""
+    AD.value_jacobian_and_hessian(ab::AD.AbstractBackend, f, x)
+
+Return the tuple `(v, J, H)` of the function value `v = f(x)`, the Jacobian `J = AD.jacobian(ab, f, x)`, and the array `H` of stacked Hessians of the outputs of `f`.
+Note that `H[i, :, :]` is the Hessian of `x -> f(x)[i]`.
+
+See also [`AbstractDifferentiation.jacobian`](@ref) and [`AbstractDifferentiation.hessian`](@ref).
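+
+# Example
+
+A sketch using the ForwardDiff backend (this assumes ForwardDiff is loaded so
+that `AD.ForwardDiffBackend` is available; the function `f` is illustrative):
+
+```julia
+import AbstractDifferentiation as AD
+using ForwardDiff
+
+f(x) = [sum(abs2, x) - 1, prod(x)]  # two outputs, three inputs
+v, (J,), (H,) = AD.value_jacobian_and_hessian(AD.ForwardDiffBackend(), f, rand(3))
+size(J)  # (2, 3)    -- the Jacobian of `f`
+size(H)  # (2, 3, 3) -- `H[i, :, :]` is the Hessian of `x -> f(x)[i]`
+```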
+""" +function value_jacobian_and_hessian(ab::AbstractBackend, f, x) + if x isa Tuple + # only support computation of Hessian for functions with single input argument + x = only(x) + end + + value = f(x) + jacs, hess = value_and_jacobian( + second_lowest(ab), _x -> begin + g = jacobian(lowest(ab), f, _x) + return g[1] # gradient returns a tuple + end, x + ) + hess = reshape(only(hess), sameindex(eachindex(value), eachindex(x), eachindex(x))) + + return value, (jacs,), (hess,) +end + + """ AD.pushforward_function(ab::AD.AbstractBackend, f, xs...) - -Return the pushforward function `pff` of the function `f` at the inputs `xs` using backend `ab`. - + +Return the pushforward function `pff` of the function `f` at the inputs `xs` using backend `ab`. + The pushfoward function `pff` accepts as input a `Tuple` of tangents, one for each element in `xs`. If `xs` consists of a single element, `pff` can also accept a single tangent instead of a 1-tuple. """ @@ -263,9 +290,9 @@ end """ AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...) - + Return a single function `vpff` which, given tangents `ts`, computes the tuple `(v, p) = vpff(ts)` composed of - + - the function value `v = f(xs...)` - the pushforward value `p = pff(ts)` given by the pushforward function `pff = AD.pushforward_function(ab, f, xs...)` applied to `ts`. @@ -308,8 +335,8 @@ end """ AD.pullback_function(ab::AD.AbstractBackend, f, xs...) -Return the pullback function `pbf` of the function `f` at the inputs `xs` using backend `ab`. - +Return the pullback function `pbf` of the function `f` at the inputs `xs` using backend `ab`. + The pullback function `pbf` accepts as input a `Tuple` of cotangents, one for each output of `f`. If `f` has a single output, `pbf` can also accept a single input instead of a 1-tuple. """ @@ -537,9 +564,9 @@ end """ AD.lazy_derivative(ab::AbstractBackend, f, xs::Number...) - + Return an operator `ld` for multiplying by the derivative of `f` at `xs`. - + You can apply the operator by multiplication e.g. `ld * y` where `y` is a number if `f` has a single input, a tuple of the same length as `xs` if `f` has multiple inputs, or an array of numbers/tuples. """ function lazy_derivative(ab::AbstractBackend, f, xs::Number...) @@ -548,9 +575,9 @@ end """ AD.lazy_gradient(ab::AbstractBackend, f, xs...) - + Return an operator `lg` for multiplying by the gradient of `f` at `xs`. - + You can apply the operator by multiplication e.g. `lg * y` where `y` is a number if `f` has a single input or a tuple of the same length as `xs` if `f` has multiple inputs. """ function lazy_gradient(ab::AbstractBackend, f, xs...) @@ -559,9 +586,9 @@ end """ AD.lazy_hessian(ab::AbstractBackend, f, x) - + Return an operator `lh` for multiplying by the Hessian of the scalar-valued function `f` at `x`. - + You can apply the operator by multiplication e.g. `lh * y` or `y' * lh` where `y` is a number or a vector of the appropriate length. """ function lazy_hessian(ab::AbstractBackend, f, xs...) @@ -570,10 +597,10 @@ end """ AD.lazy_jacobian(ab::AbstractBackend, f, xs...) - + Return an operator `lj` for multiplying by the Jacobian of `f` at `xs`. - -You can apply the operator by multiplication e.g. `lj * y` or `y' * lj` where `y` is a number, vector or tuple of numbers and/or vectors. + +You can apply the operator by multiplication e.g. `lj * y` or `y' * lj` where `y` is a number, vector or tuple of numbers and/or vectors. If `f` has multiple inputs, `y` in `lj * y` should be a tuple. 
 """
@@ -640,7 +667,7 @@ function define_pushforward_function_and_friends(fdef)
     elseif eltype(identity_like) <: AbstractMatrix
         # needed for the computation of the Hessian and Jacobian
         ret = hcat.(mapslices(identity_like[1]; dims=1) do cols
-            # cols loop over basis states 
+            # cols loop over basis states
             pf = pff((cols,))
             if typeof(pf) <: AbstractVector
                 # to make the hcat. work / get correct matrix-like, non-flat output dimension
@@ -676,7 +703,7 @@ function define_value_and_pullback_function_and_friends(fdef)
     elseif eltype(identity_like) <: AbstractMatrix
         # needed for Hessian computation:
         # value is a (grad,). Then, identity_like is a (matrix,).
-        # cols loops over columns of the matrix 
+        # cols loops over columns of the matrix
         return vcat.(mapslices(identity_like[1]; dims=1) do cols
             adjoint.(pbf((cols,)))
         end...)
@@ -695,6 +722,18 @@ function identity_matrix_like(x)
     throw("The function `identity_matrix_like` is not defined for the type $(typeof(x)).")
 end
 
+# For an `m × n` matrix, return an `m × n × n` array `A` in which each slice
+# `A[i, :, :]` is the `n × n` identity. This is the basis "seed" used when
+# differentiating a matrix-valued intermediate, e.g. the Jacobian inside
+# `value_jacobian_and_hessian`.
+function identity_matrix_like(X::AbstractMatrix)
+    Base.require_one_based_indexing(X)
+    m, n = size(X)
+    A = Array{eltype(X)}(undef, m, n, n)
+    fill!(A, false)
+    for i in 1:m, j in 1:n
+        A[i, j, j] = true
+    end
+    return (A,)
+    # return (reshape(A, m*n, n),)
+end
+
 function identity_matrix_like(x::AbstractVector)
     return (Matrix{eltype(x)}(I, length(x), length(x)),)
 end
@@ -732,6 +771,9 @@ end
 @inline asarray(x) = [x]
 @inline asarray(x::AbstractArray) = x
 
+# Fallback: axes that already share a common type are returned unchanged.
+# Extensions add methods for mixed axis types, e.g. `Base.OneTo` together with
+# `StaticArrays.SOneTo` (see ext/AbstractDifferentiationStaticArraysExt.jl).
+sameindex(idxs::R...) where R<:AbstractUnitRange = idxs
+
 include("backends.jl")
 
 # TODO: Replace with proper version
@@ -761,6 +803,9 @@ end
         @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
             "../ext/AbstractDifferentiationZygoteExt.jl"
         )
+        @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" include(
+            "../ext/AbstractDifferentiationStaticArraysExt.jl"
+        )
     end
 end
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 22e00c3..03f7943 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,4 +1,5 @@
 using AbstractDifferentiation
+using StaticArrays
 using Test, LinearAlgebra
 using Random
 Random.seed!(1234)
@@ -20,6 +21,10 @@ end
 dfjacdx(x, y) = I(length(x))
 dfjacdy(x, y) = Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U)
 
+# Vector-valued test function with known derivatives. The stacked Hessians are
+# ∂²(sum(abs2, x) - 1) = 2I and ∂²(prod(x))[i, j] = prod(x) / (x[i] * x[j]) off
+# the diagonal and 0 on it.
+fjac2(x) = [sum(abs2, x) - 1; prod(x)]
+dfjac2dx(x) = [2x'; prod(x) ./ x']
+dfjac2dxdx(x) = permutedims(cat(2I(length(x)), prod(x) ./ (x * x') - Diagonal(diag(prod(x) ./ (x * x'))); dims=3), (3, 1, 2))
+
 # Jvp
 jxvp(x, y, v) = dfjacdx(x, y) * v
 jyvp(x, y, v) = dfjacdy(x, y) * v
@@ -33,6 +38,7 @@ const yscalar = rand()
 
 const xvec = rand(5)
 const yvec = rand(5)
+const sxvec = @SVector rand(5)
 
 # to check if vectors get mutated
 xvec2 = deepcopy(xvec)
@@ -184,7 +190,7 @@ end
 
 function test_hessians(backend; multiple_inputs=false, test_types=true)
     if multiple_inputs
-        # ... but 
+        # ... but
         error("multiple_inputs=true is not supported.")
     else
         # explicit test that AbstractDifferentiation throws an error
@@ -230,6 +236,16 @@ function test_hessians(backend; multiple_inputs=false, test_types=true)
         @test hess4[1] isa Matrix{Float64}
     end
     @test minimum(isapprox.(hess4, hess1, atol=1e-10))
+
+    val, jac, hess = AD.value_jacobian_and_hessian(backend, fjac2, xvec)
+    @test val == fjac2(xvec)
+    @test only(jac) ≈ dfjac2dx(xvec) atol = 1e-10
+    @test only(hess) ≈ dfjac2dxdx(xvec) atol = 1e-10
+
+    # repeat with a static input vector to exercise the StaticArrays extension
+    val, jac, hess = AD.value_jacobian_and_hessian(backend, fjac2, sxvec)
+    @test val == fjac2(sxvec)
+    @test only(jac) ≈ dfjac2dx(sxvec) atol = 1e-10
+    @test only(hess) ≈ dfjac2dxdx(sxvec) atol = 1e-10
 end
 
 function test_jvp(