diff --git a/Project.toml b/Project.toml index 8f63f9dc..d227338a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.00.0" +version = "2.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -12,16 +12,24 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" +[weakdeps] +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +SparseDiffToolsZygote = "Zygote" + [compat] Adapt = "3.0" -ArrayInterface = "6, 7" +ArrayInterface = "7" Compat = "4" DataStructures = "0.18" FiniteDiff = "2.8.1" @@ -30,8 +38,10 @@ Graphs = "1" Requires = "1" SciMLOperators = "0.1.19, 0.2" StaticArrays = "1" +StaticArrayInterface = "1.3" Tricks = "0.1.6" VertexSafeGraphs = "0.2" +Zygote = "0.6" julia = "1.6" [extras] diff --git a/src/differentiation/jaches_products_zygote.jl b/ext/SparseDiffToolsZygote.jl similarity index 58% rename from src/differentiation/jaches_products_zygote.jl rename to ext/SparseDiffToolsZygote.jl index 3187d3ff..377820b2 100644 --- a/src/differentiation/jaches_products_zygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -1,4 +1,24 @@ -function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) +module SparseDiffToolsZygote + +if isdefined(Base, :get_extension) + import Zygote + using LinearAlgebra + using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac + using ForwardDiff: ForwardDiff, Dual, partials + using SciMLOperators: FunctionOperator + using Tricks: static_hasmethod +else + import ..Zygote + using ..LinearAlgebra + using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac + using ..ForwardDiff: ForwardDiff, Dual, partials + using ..SciMLOperators: FunctionOperator + using ..Tricks: static_hasmethod +end + +### Jac, Hes products + +function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end @@ -12,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) @. dy = (cache1 - cache2) / (2ϵ) end -function numback_hesvec(f, x, v) +function SparseDiffTools.numback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f, x)) T = eltype(x) # Should it be min? max? mean? @@ -24,7 +44,7 @@ function numback_hesvec(f, x, v) (gxp - gxm) / (2ϵ) end -function autoback_hesvec!(dy, f, x, v, +function SparseDiffTools.autoback_hesvec!(dy, f, x, v, cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, @@ -42,16 +62,16 @@ function autoback_hesvec!(dy, f, x, v, dy .= partials.(cache2, 1) end -function autoback_hesvec(f, x, v) +function SparseDiffTools.autoback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f, x)) y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) ForwardDiff.partials.(g(y), 1) end -### Operator Forms +# Operator Forms -function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) +function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) if autodiff cache1 = Dual{ @@ -65,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = cache = (cache1, cache2,) - vecprod = autodiff ? autoback_hesvec : numback_hesvec - vecprod! = autodiff ? autoback_hesvec! : numback_hesvec! + vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec + vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec! outofplace = static_hasmethod(f, typeof((u,))) isinplace = static_hasmethod(f, typeof((u,))) @@ -82,4 +102,21 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = p = p, t = t, islinear = true, ) end -# + +## VecJac products + +function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) + !hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false") + du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du)) +end + +function SparseDiffTools.auto_vecjac(f, x, v) + vv, back = Zygote.pullback(f, x) + return vec(back(reshape(v, size(vv)))[1]) +end + +function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...) + VecJac(args...; autodiff = autodiff, kwargs...) +end + +end # module diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 621f7dbf..95c72a1a 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -5,7 +5,6 @@ using FiniteDiff using ForwardDiff using Graphs using Graphs: SimpleGraph -using Requires using VertexSafeGraphs using Adapt @@ -21,7 +20,7 @@ using ArrayInterface: matrix_colors using SciMLOperators import SciMLOperators: update_coefficients, update_coefficients! -using Tricks: static_hasmethod +using Tricks: Tricks, static_hasmethod abstract type AbstractAutoDiffVecProd end @@ -69,16 +68,31 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -function __init__() - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - export numback_hesvec, numback_hesvec!, - autoback_hesvec, autoback_hesvec!, - auto_vecjac, auto_vecjac!, - ZygoteVecJac, ZygoteHesVec +import Requires +import Reexport - include("differentiation/vecjac_products_zygote.jl") - include("differentiation/jaches_products_zygote.jl") +function numback_hesvec end +function numback_hesvec! end +function autoback_hesvec end +function autoback_hesvec! end +function auto_vecjac end +function auto_vecjac! end +function ZygoteVecJac end +function ZygoteHesVec end + +@static if !isdefined(Base, :get_extension) + function __init__() + Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/SparseDiffToolsZygote.jl") + Reexport.@reexport using .SparseDiffToolsZygote + end end end +export + numback_hesvec, numback_hesvec!, + autoback_hesvec, autoback_hesvec!, + auto_vecjac, auto_vecjac!, + ZygoteVecJac, ZygoteHesVec + end # module diff --git a/src/differentiation/vecjac_products_zygote.jl b/src/differentiation/vecjac_products_zygote.jl deleted file mode 100644 index f5f22623..00000000 --- a/src/differentiation/vecjac_products_zygote.jl +++ /dev/null @@ -1,14 +0,0 @@ -function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) - !hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false") - du .= reshape(auto_vecjac(f, x, v), size(du)) -end - -function auto_vecjac(f, x, v) - vv, back = Zygote.pullback(f, x) - return vec(back(reshape(v, size(vv)))[1]) -end - -#ZygoteVecJac = VecJac -ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...) - -#