Skip to content

Zygote ext #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.00.0"
version = "2.0.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -12,16 +12,24 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"

[weakdeps]
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SparseDiffToolsZygote = "Zygote"

[compat]
Adapt = "3.0"
ArrayInterface = "6, 7"
ArrayInterface = "7"
Compat = "4"
DataStructures = "0.18"
FiniteDiff = "2.8.1"
Expand All @@ -30,8 +38,10 @@ Graphs = "1"
Requires = "1"
SciMLOperators = "0.1.19, 0.2"
StaticArrays = "1"
StaticArrayInterface = "1.3"
Tricks = "0.1.6"
VertexSafeGraphs = "0.2"
Zygote = "0.6"
julia = "1.6"

[extras]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
module SparseDiffToolsZygote

if isdefined(Base, :get_extension)
import Zygote
using LinearAlgebra
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
using ForwardDiff: ForwardDiff, Dual, partials
using SciMLOperators: FunctionOperator
using Tricks: static_hasmethod
else
import ..Zygote
using ..LinearAlgebra
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
using ..ForwardDiff: ForwardDiff, Dual, partials
using ..SciMLOperators: FunctionOperator
using ..Tricks: static_hasmethod
end

### Jac, Hes products

function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand All @@ -12,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
@. dy = (cache1 - cache2) / (2ϵ)
end

function numback_hesvec(f, x, v)
function SparseDiffTools.numback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
T = eltype(x)
# Should it be min? max? mean?
Expand All @@ -24,7 +44,7 @@ function numback_hesvec(f, x, v)
(gxp - gxm) / (2ϵ)
end

function autoback_hesvec!(dy, f, x, v,
function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
eltype(x), 1
}.(x,
Expand All @@ -42,16 +62,16 @@ function autoback_hesvec!(dy, f, x, v,
dy .= partials.(cache2, 1)
end

function autoback_hesvec(f, x, v)
function SparseDiffTools.autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end

### Operator Forms
# Operator Forms

function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = Dual{
Expand All @@ -65,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =

cache = (cache1, cache2,)

vecprod = autodiff ? autoback_hesvec : numback_hesvec
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))
Expand All @@ -82,4 +102,21 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
p = p, t = t, islinear = true,
)
end
#

## VecJac products

function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
end

function SparseDiffTools.auto_vecjac(f, x, v)
vv, back = Zygote.pullback(f, x)
return vec(back(reshape(v, size(vv)))[1])
end

function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
VecJac(args...; autodiff = autodiff, kwargs...)
end

end # module
34 changes: 24 additions & 10 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using FiniteDiff
using ForwardDiff
using Graphs
using Graphs: SimpleGraph
using Requires
using VertexSafeGraphs
using Adapt

Expand All @@ -21,7 +20,7 @@ using ArrayInterface: matrix_colors

using SciMLOperators
import SciMLOperators: update_coefficients, update_coefficients!
using Tricks: static_hasmethod
using Tricks: Tricks, static_hasmethod

abstract type AbstractAutoDiffVecProd end

Expand Down Expand Up @@ -69,16 +68,31 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

function __init__()
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
export numback_hesvec, numback_hesvec!,
autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!,
ZygoteVecJac, ZygoteHesVec
import Requires
import Reexport

include("differentiation/vecjac_products_zygote.jl")
include("differentiation/jaches_products_zygote.jl")
function numback_hesvec end
function numback_hesvec! end
function autoback_hesvec end
function autoback_hesvec! end
function auto_vecjac end
function auto_vecjac! end
function ZygoteVecJac end
function ZygoteHesVec end

@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/SparseDiffToolsZygote.jl")
Reexport.@reexport using .SparseDiffToolsZygote
end
end
end

export
numback_hesvec, numback_hesvec!,
autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!,
ZygoteVecJac, ZygoteHesVec

end # module
14 changes: 0 additions & 14 deletions src/differentiation/vecjac_products_zygote.jl

This file was deleted.