8 changes: 6 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <[email protected]> and contributors"]
version = "0.6.2"
version = "0.6.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,6 +15,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -23,6 +24,7 @@ AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore"
AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences"
AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"]
AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"]
AbstractDifferentiationStaticArraysExt = "StaticArrays"
AbstractDifferentiationTrackerExt = "Tracker"
AbstractDifferentiationZygoteExt = "Zygote"

@@ -35,6 +37,7 @@ FiniteDifferences = "0.12"
ForwardDiff = "0.10"
Requires = "1"
ReverseDiff = "1"
StaticArrays = "1"
Tracker = "0.2"
Zygote = "0.6"
julia = "1.6"
@@ -47,9 +50,10 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "Tracker", "Zygote"]
3 changes: 2 additions & 1 deletion docs/src/implementer_guide.md
@@ -1,7 +1,7 @@
# Implementer guide

!!! warning "Work in progress"
    Come back later!

## The macro `@primitive`
@@ -37,6 +37,7 @@ They are just listed here to help readers figure out the code structure:
- `value_and_second_derivative` calls `second_derivative`
- `value_gradient_and_hessian` calls `value_and_jacobian` and `gradient`
- `value_derivative_and_second_derivative` calls `value_and_derivative` and `second_derivative`
- `value_jacobian_and_hessian` calls `value_and_jacobian` and `jacobian`
- `pushforward_function` calls `jacobian`
- `value_and_pushforward_function` calls `pushforward_function`
- `pullback_function` calls `value_and_pullback_function`
1 change: 1 addition & 0 deletions docs/src/user_guide.md
@@ -75,6 +75,7 @@ AbstractDifferentiation.value_and_second_derivative
AbstractDifferentiation.value_and_hessian
AbstractDifferentiation.value_derivative_and_second_derivative
AbstractDifferentiation.value_gradient_and_hessian
AbstractDifferentiation.value_jacobian_and_hessian
```
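
A minimal usage sketch for the new `value_jacobian_and_hessian` (the function `f` is illustrative; ForwardDiff and DiffResults must both be loaded to activate the ForwardDiff extension):

```julia
import AbstractDifferentiation as AD
using ForwardDiff, DiffResults

f(x) = [sum(abs2, x) - 1, prod(x)]  # vector-valued function of one vector input

x = rand(3)
v, (J,), (H,) = AD.value_jacobian_and_hessian(AD.ForwardDiffBackend(), f, x)
# v == f(x); J is the 2×3 Jacobian; H[i, :, :] is the Hessian of f(x)[i]
```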

## Jacobian-vector products
8 changes: 8 additions & 0 deletions ext/AbstractDifferentiationStaticArraysExt.jl
@@ -0,0 +1,8 @@
module AbstractDifferentiationStaticArraysExt

import AbstractDifferentiation as AD
using StaticArrays

AD.sameindex(idxs::Union{Base.OneTo, StaticArrays.SOneTo}...) = map(Base.OneTo, idxs)

end
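
Why this method is needed: the generic fallback in `src/AbstractDifferentiation.jl`, `sameindex(idxs::R...) where R<:AbstractUnitRange = idxs`, only matches when every axis has the same concrete type, and `SOneTo{n}` encodes its length in its type. The method above therefore normalizes any mix of `Base.OneTo` and `SOneTo` axes. A small sketch:

```julia
import AbstractDifferentiation as AD
using StaticArrays

AD.sameindex(Base.OneTo(2), StaticArrays.SOneTo(5), StaticArrays.SOneTo(5))
# -> (Base.OneTo(2), Base.OneTo(5), Base.OneTo(5))
```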
89 changes: 67 additions & 22 deletions src/AbstractDifferentiation.jl
@@ -136,7 +136,7 @@ end
AD.value_and_gradient(ab::AD.AbstractBackend, f, xs...)

Return the tuple `(v, gs)` of the function value `v = f(xs...)` and the gradients `gs = AD.gradient(ab, f, xs...)`.

See also [`AbstractDifferentiation.gradient`](@ref).
"""
function value_and_gradient(ab::AbstractBackend, f, xs...)
@@ -148,7 +148,7 @@ end
AD.value_and_jacobian(ab::AD.AbstractBackend, f, xs...)

Return the tuple `(v, Js)` of the function value `v = f(xs...)` and the Jacobians `Js = AD.jacobian(ab, f, xs...)`.

See also [`AbstractDifferentiation.jacobian`](@ref).
"""
function value_and_jacobian(ab::AbstractBackend, f, xs...)
@@ -173,7 +173,7 @@ end

Return the tuple `(v, H)` of the function value `v = f(x)` and the Hessian `H = AD.hessian(ab, f, x)`.

See also [`AbstractDifferentiation.hessian`](@ref).
"""
function value_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
@@ -214,7 +214,7 @@ end

"""
AD.value_gradient_and_hessian(ab::AD.AbstractBackend, f, x)

Return the tuple `(v, g, H)` of the function value `v = f(x)`, the gradient `g = AD.gradient(ab, f, x)`, and the Hessian `H = AD.hessian(ab, f, x)`.

See also [`AbstractDifferentiation.gradient`](@ref) and [`AbstractDifferentiation.hessian`](@ref).
@@ -236,11 +236,38 @@ function value_gradient_and_hessian(ab::AbstractBackend, f, x)
return value, (grads,), hess
end

"""
AD.value_jacobian_and_hessian(ab::AD.AbstractBackend, f, x)

Return the tuple `(v, J, H)` of the function value `v = f(x)`, the Jacobian `J = AD.jacobian(ab, f, x)`, and the stacked Hessians `H` of the outputs of `f`.
Note that `H[i, :, :]` is the Hessian of `f(x)[i]`.

See also [`AbstractDifferentiation.jacobian`](@ref) and [`AbstractDifferentiation.hessian`](@ref).
"""
function value_jacobian_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)
end

value = f(x)
jacs, hess = value_and_jacobian(
second_lowest(ab), _x -> begin
g = jacobian(lowest(ab), f, _x)
return g[1] # jacobian returns a tuple (one Jacobian per input)
end, x
)
hess = reshape(only(hess), sameindex(eachindex(value), eachindex(x), eachindex(x)))

return value, (jacs,), (hess,)
end
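
# Shape note: for `f` mapping an n-vector to an m-vector, `jacs` is the m×n
# Jacobian and `hess` is reshaped to m×n×n, so `hess[i, :, :]` is the Hessian
# of the i-th output. `sameindex` normalizes the axis types handed to `reshape`
# (the StaticArrays extension converts `SOneTo` axes to `Base.OneTo`).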


"""
AD.pushforward_function(ab::AD.AbstractBackend, f, xs...)

Return the pushforward function `pff` of the function `f` at the inputs `xs` using backend `ab`.

The pushfoward function `pff` accepts as input a `Tuple` of tangents, one for each element in `xs`.
If `xs` consists of a single element, `pff` can also accept a single tangent instead of a 1-tuple.
"""
@@ -263,9 +290,9 @@ end

"""
AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...)

Return a single function `vpff` which, given tangents `ts`, computes the tuple `(v, p) = vpff(ts)` composed of

- the function value `v = f(xs...)`
- the pushforward value `p = pff(ts)` given by the pushforward function `pff = AD.pushforward_function(ab, f, xs...)` applied to `ts`.

@@ -308,8 +335,8 @@ end
"""
AD.pullback_function(ab::AD.AbstractBackend, f, xs...)

Return the pullback function `pbf` of the function `f` at the inputs `xs` using backend `ab`.

The pullback function `pbf` accepts as input a `Tuple` of cotangents, one for each output of `f`.
If `f` has a single output, `pbf` can also accept a single input instead of a 1-tuple.
"""
@@ -537,9 +564,9 @@ end

"""
AD.lazy_derivative(ab::AbstractBackend, f, xs::Number...)

Return an operator `ld` for multiplying by the derivative of `f` at `xs`.

You can apply the operator by multiplication e.g. `ld * y` where `y` is a number if `f` has a single input, a tuple of the same length as `xs` if `f` has multiple inputs, or an array of numbers/tuples.
"""
function lazy_derivative(ab::AbstractBackend, f, xs::Number...)
@@ -548,9 +575,9 @@ end

"""
AD.lazy_gradient(ab::AbstractBackend, f, xs...)

Return an operator `lg` for multiplying by the gradient of `f` at `xs`.

You can apply the operator by multiplication e.g. `lg * y` where `y` is a number if `f` has a single input or a tuple of the same length as `xs` if `f` has multiple inputs.
"""
function lazy_gradient(ab::AbstractBackend, f, xs...)
@@ -559,9 +586,9 @@ end

"""
AD.lazy_hessian(ab::AbstractBackend, f, x)

Return an operator `lh` for multiplying by the Hessian of the scalar-valued function `f` at `x`.

You can apply the operator by multiplication e.g. `lh * y` or `y' * lh` where `y` is a number or a vector of the appropriate length.
"""
function lazy_hessian(ab::AbstractBackend, f, xs...)
@@ -570,10 +597,10 @@ end

"""
AD.lazy_jacobian(ab::AbstractBackend, f, xs...)

Return an operator `lj` for multiplying by the Jacobian of `f` at `xs`.

You can apply the operator by multiplication e.g. `lj * y` or `y' * lj` where `y` is a number, vector or tuple of numbers and/or vectors.
If `f` has multiple inputs, `y` in `lj * y` should be a tuple.
If `f` has multiple outputs, `y` in `y' * lj` should be a tuple.
Otherwise, it should be a scalar or a vector of the appropriate length.
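
A minimal sketch of applying the operator (the function and backend are illustrative):

```julia
lj = AD.lazy_jacobian(AD.ForwardDiffBackend(), x -> x .^ 2, rand(3))
lj * ones(3)    # behaves like J * v (a Jacobian-vector product)
ones(3)' * lj   # behaves like v' * J (a vector-Jacobian product)
```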
@@ -640,7 +667,7 @@ function define_pushforward_function_and_friends(fdef)
elseif eltype(identity_like) <: AbstractMatrix
# needed for the computation of the Hessian and Jacobian
ret = hcat.(mapslices(identity_like[1]; dims=1) do cols
# cols loop over basis states
pf = pff((cols,))
if typeof(pf) <: AbstractVector
# to make the hcat. work / get correct matrix-like, non-flat output dimension
@@ -676,7 +703,7 @@ function define_value_and_pullback_function_and_friends(fdef)
elseif eltype(identity_like) <: AbstractMatrix
# needed for Hessian computation:
# value is a (grad,). Then, identity_like is a (matrix,).
# cols loops over columns of the matrix
return vcat.(mapslices(identity_like[1]; dims=1) do cols
adjoint.(pbf((cols,)))
end...)
@@ -695,6 +722,18 @@ function identity_matrix_like(x)
throw("The function `identity_matrix_like` is not defined for the type $(typeof(x)).")
end

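# Identity-like seed for a matrix-valued intermediate (e.g. a Jacobian that is
# differentiated a second time): for an m×n matrix, build an m×n×n array `A`
# with `A[i, j, j] = 1`, i.e. one basis slice per input column.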
function identity_matrix_like(X::AbstractMatrix)
Base.require_one_based_indexing(X)
m, n = size(X)
A = Array{eltype(X)}(undef, m, n, n)
fill!(A, false)
for i in 1:m, j in 1:n
A[i, j, j] = true
end
return (A,)
end

function identity_matrix_like(x::AbstractVector)
return (Matrix{eltype(x)}(I, length(x), length(x)),)
end
@@ -732,6 +771,9 @@ end
@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x

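# Normalize the axes handed to `reshape` in `value_jacobian_and_hessian`.
# This generic method only matches when all axes share one concrete type;
# the StaticArrays extension adds a method converting `SOneTo` axes
# (whose type depends on their length) to `Base.OneTo`.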
sameindex(idxs::R...) where R<:AbstractUnitRange = idxs


include("backends.jl")

# TODO: Replace with proper version
@@ -761,6 +803,9 @@ end
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
"../ext/AbstractDifferentiationZygoteExt.jl"
)
@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" include(
"../ext/AbstractDifferentiationStaticArraysExt.jl"
)
end
end

18 changes: 17 additions & 1 deletion test/test_utils.jl
@@ -1,4 +1,5 @@
using AbstractDifferentiation
using StaticArrays
using Test, LinearAlgebra
using Random
Random.seed!(1234)
@@ -20,6 +21,10 @@ end
dfjacdx(x, y) = I(length(x))
dfjacdy(x, y) = Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U)

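# fjac2 maps an n-vector to a 2-vector; dfjac2dx is its analytic 2×n Jacobian
# and dfjac2dxdx its stacked 2×n×n Hessian: H[1, :, :] = 2I and
# H[2, i, j] = prod(x)/(x[i]*x[j]) for i ≠ j, with zeros on the diagonal.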
fjac2(x) = [sum(abs2, x) - 1; prod(x)]
dfjac2dx(x) = [2x'; prod(x) ./ x']
dfjac2dxdx(x) = permutedims(cat(2I(length(x)), prod(x) ./ (x * x') - Diagonal(diag(prod(x) ./ (x * x'))); dims=3), (3, 1, 2))

# Jvp
jxvp(x, y, v) = dfjacdx(x, y) * v
jyvp(x, y, v) = dfjacdy(x, y) * v
@@ -33,6 +38,7 @@ const yscalar = rand()

const xvec = rand(5)
const yvec = rand(5)
const sxvec = @SVector rand(5)

# to check if vectors get mutated
xvec2 = deepcopy(xvec)
@@ -184,7 +190,7 @@ end

function test_hessians(backend; multiple_inputs=false, test_types=true)
if multiple_inputs
# ... but
error("multiple_inputs=true is not supported.")
else
# explicit test that AbstractDifferentiation throws an error
@@ -230,6 +236,16 @@ function test_hessians(backend; multiple_inputs=false, test_types=true)
@test hess4[1] isa Matrix{Float64}
end
@test minimum(isapprox.(hess4, hess1, atol=1e-10))

val, jac, hess = AD.value_jacobian_and_hessian(backend, fjac2, xvec)
@test val == fjac2(xvec)
@test only(jac) ≈ dfjac2dx(xvec) atol = 1e-10
@test only(hess) ≈ dfjac2dxdx(xvec) atol = 1e-10

val, jac, hess = AD.value_jacobian_and_hessian(backend, fjac2, sxvec)
@test val == fjac2(sxvec)
@test only(jac) ≈ dfjac2dx(sxvec) atol = 1e-10
@test only(hess) ≈ dfjac2dxdx(sxvec) atol = 1e-10
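
# The sxvec case exercises the StaticArrays extension: eachindex of an SVector
# returns SOneTo, which AD.sameindex converts to Base.OneTo before the reshape
# inside value_jacobian_and_hessian.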
end

function test_jvp(