From d989fa8236eed75a478d47df8e1a9b291f9ed1c1 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 7 Mar 2023 20:12:34 -0500 Subject: [PATCH 01/11] dummy change to run tests --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8f63f9dc..51a3c815 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.00.0" +version = "2.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From f876266e3838067c047fb671bc1e5f23e1866a36 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 7 Mar 2023 20:29:59 -0500 Subject: [PATCH 02/11] zygote ext file --- .../SparseDiffToolsZygote.jl | 40 ++++++++++++++++++- src/SparseDiffTools.jl | 19 +++++---- src/differentiation/vecjac_products_zygote.jl | 14 ------- 3 files changed, 47 insertions(+), 26 deletions(-) rename src/differentiation/jaches_products_zygote.jl => ext/SparseDiffToolsZygote.jl (75%) delete mode 100644 src/differentiation/vecjac_products_zygote.jl diff --git a/src/differentiation/jaches_products_zygote.jl b/ext/SparseDiffToolsZygote.jl similarity index 75% rename from src/differentiation/jaches_products_zygote.jl rename to ext/SparseDiffToolsZygote.jl index 3187d3ff..5042c095 100644 --- a/src/differentiation/jaches_products_zygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -1,3 +1,24 @@ +module SparseDiffToolsZygote + +import Zygote + +using SparseDiffTools +using SparseDiffTools: DeviVecTag, FwdModeAutoDiffVecProd + +using ForwardDiff +using ForwardDiff: Dual, Tag + +using SciMLOperators: FunctionOperator +using Tricks: static_hasmethod + +export + numback_hesvec, numback_hesvec!, + autoback_hesvec, autoback_hesvec!, + auto_vecjac, auto_vecjac!, + ZygoteVecJac, ZygoteHesVec + +### Jac, Hes products + function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) @@ -49,7 +70,7 @@ function autoback_hesvec(f, x, v) ForwardDiff.partials.(g(y), 1) end -### Operator Forms +# Operator Forms function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) @@ -82,4 +103,19 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = p = p, t = t, islinear = true, ) end -# + +## VecJac products + +function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) + !hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false") + du .= reshape(auto_vecjac(f, x, v), size(du)) +end + +function auto_vecjac(f, x, v) + vv, back = Zygote.pullback(f, x) + return vec(back(reshape(v, size(vv)))[1]) +end + +ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...) + +end # module diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 621f7dbf..88cf02c9 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -5,7 +5,6 @@ using FiniteDiff using ForwardDiff using Graphs using Graphs: SimpleGraph -using Requires using VertexSafeGraphs using Adapt @@ -69,16 +68,16 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -function __init__() - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - export numback_hesvec, numback_hesvec!, - autoback_hesvec, autoback_hesvec!, - auto_vecjac, auto_vecjac!, - ZygoteVecJac, ZygoteHesVec +#if !isdefined(Base, :get_extension) + using Requires +#end - include("differentiation/vecjac_products_zygote.jl") - include("differentiation/jaches_products_zygote.jl") - end +function __init__() + #@static if !isdefined(Base, :get_extension) + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/SparseDiffToolsZygote.jl") + end + #end end end # module diff --git a/src/differentiation/vecjac_products_zygote.jl b/src/differentiation/vecjac_products_zygote.jl deleted file mode 100644 index f5f22623..00000000 --- a/src/differentiation/vecjac_products_zygote.jl +++ /dev/null @@ -1,14 +0,0 @@ -function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) - !hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false") - du .= reshape(auto_vecjac(f, x, v), size(du)) -end - -function auto_vecjac(f, x, v) - vv, back = Zygote.pullback(f, x) - return vec(back(reshape(v, size(vv)))[1]) -end - -#ZygoteVecJac = VecJac -ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...) - -# From 2c3891c189d74b161812fe2da0d4accd07c2ccae Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 7 Mar 2023 20:44:43 -0500 Subject: [PATCH 03/11] follow https://pkgdocs.julialang.org/dev/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions) --- Project.toml | 6 ++++++ src/SparseDiffTools.jl | 8 ++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 51a3c815..da50c216 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" +[weakdeps] +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +SparseDiffToolsZygote = "Zygote" + [compat] Adapt = "3.0" ArrayInterface = "6, 7" diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 88cf02c9..0077cc53 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -68,16 +68,16 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -#if !isdefined(Base, :get_extension) +if !isdefined(Base, :get_extension) using Requires -#end +end function __init__() - #@static if !isdefined(Base, :get_extension) + @static if !isdefined(Base, :get_extension) @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/SparseDiffToolsZygote.jl") end - #end + end end end # module From 27f2aa90d9a12e6c7e44c38ed5d7462fa055fd54 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 8 Mar 2023 09:47:42 -0500 Subject: [PATCH 04/11] Update ext/SparseDiffToolsZygote.jl Co-authored-by: Christopher Rackauckas --- ext/SparseDiffToolsZygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 5042c095..19c91734 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -9,7 +9,7 @@ using ForwardDiff using ForwardDiff: Dual, Tag using SciMLOperators: FunctionOperator -using Tricks: static_hasmethod +using SparseDiffTools.Tricks: static_hasmethod export numback_hesvec, numback_hesvec!, From c0ca189809f8c0fd07a7b53362e3f10efe2d0e59 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 8 Mar 2023 09:47:53 -0500 Subject: [PATCH 05/11] Update ext/SparseDiffToolsZygote.jl Co-authored-by: Christopher Rackauckas --- ext/SparseDiffToolsZygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 19c91734..16c691b6 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -5,7 +5,7 @@ import Zygote using SparseDiffTools using SparseDiffTools: DeviVecTag, FwdModeAutoDiffVecProd -using ForwardDiff +using SparseDiffTools.ForwardDiff using ForwardDiff: Dual, Tag using SciMLOperators: FunctionOperator From b99cdf9090429f3bb86126c65cb2bb2a9e9439d5 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 8 Mar 2023 09:48:04 -0500 Subject: [PATCH 06/11] Update ext/SparseDiffToolsZygote.jl Co-authored-by: Christopher Rackauckas --- ext/SparseDiffToolsZygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 16c691b6..2403dfa3 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -8,7 +8,7 @@ using SparseDiffTools: DeviVecTag, FwdModeAutoDiffVecProd using SparseDiffTools.ForwardDiff using ForwardDiff: Dual, Tag -using SciMLOperators: FunctionOperator +using SparseDiffTools.SciMLOperators: FunctionOperator using SparseDiffTools.Tricks: static_hasmethod export From ea28647b6ea7c7f9f94da3c7e6b4643fe9445b4e Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 11 Mar 2023 12:07:13 -0500 Subject: [PATCH 07/11] exports working in 1x8 --- Project.toml | 2 ++ ext/SparseDiffToolsZygote.jl | 23 +++++++++++++---------- src/SparseDiffTools.jl | 14 +++++++------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index da50c216..958fc37d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -38,6 +39,7 @@ SciMLOperators = "0.1.19, 0.2" StaticArrays = "1" Tricks = "0.1.6" VertexSafeGraphs = "0.2" +Zygote = "0.6" julia = "1.6" [extras] diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 2403dfa3..e105b03c 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -1,15 +1,18 @@ module SparseDiffToolsZygote -import Zygote - -using SparseDiffTools -using SparseDiffTools: DeviVecTag, FwdModeAutoDiffVecProd - -using SparseDiffTools.ForwardDiff -using ForwardDiff: Dual, Tag - -using SparseDiffTools.SciMLOperators: FunctionOperator -using SparseDiffTools.Tricks: static_hasmethod +if isdefined(Base, :get_extension) + import Zygote + using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd + using ForwardDiff: ForwardDiff, Dual + using SciMLOperators: FunctionOperator + using Tricks: static_hasmethod +else + import ..Zygote + using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd + using ..ForwardDiff: ForwardDiff, Dual + using ..SciMLOperators: FunctionOperator + using ..Tricks: static_hasmethod +end export numback_hesvec, numback_hesvec!, diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 0077cc53..9bb5e89f 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -20,7 +20,7 @@ using ArrayInterface: matrix_colors using SciMLOperators import SciMLOperators: update_coefficients, update_coefficients! -using Tricks: static_hasmethod +using Tricks: Tricks, static_hasmethod abstract type AbstractAutoDiffVecProd end @@ -68,14 +68,14 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -if !isdefined(Base, :get_extension) - using Requires -end +import Requires +import Reexport -function __init__() - @static if !isdefined(Base, :get_extension) - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin +@static if !isdefined(Base, :get_extension) + function __init__() + Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/SparseDiffToolsZygote.jl") + Reexport.@reexport using .SparseDiffToolsZygote end end end From 8596f605389b24d7fc39ae96b27cb6d45e458e4e Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 11 Mar 2023 12:14:40 -0500 Subject: [PATCH 08/11] 1x8 test passing --- ext/SparseDiffToolsZygote.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index e105b03c..8bc5c1b9 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -2,14 +2,16 @@ module SparseDiffToolsZygote if isdefined(Base, :get_extension) import Zygote - using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd + using LinearAlgebra + using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac using ForwardDiff: ForwardDiff, Dual using SciMLOperators: FunctionOperator using Tricks: static_hasmethod else import ..Zygote - using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd - using ..ForwardDiff: ForwardDiff, Dual + using ..LinearAlgebra + using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac + using ..ForwardDiff: ForwardDiff, Dual, partials using ..SciMLOperators: FunctionOperator using ..Tricks: static_hasmethod end From 3274b71620fd34b8523834a41920db3a99eced59 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 12 Mar 2023 08:14:01 -0400 Subject: [PATCH 09/11] move function declarations to main --- ext/SparseDiffToolsZygote.jl | 6 ------ src/SparseDiffTools.jl | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 8bc5c1b9..a157b673 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -16,12 +16,6 @@ else using ..Tricks: static_hasmethod end -export - numback_hesvec, numback_hesvec!, - autoback_hesvec, autoback_hesvec!, - auto_vecjac, auto_vecjac!, - ZygoteVecJac, ZygoteHesVec - ### Jac, Hes products function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 9bb5e89f..95c72a1a 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -71,6 +71,15 @@ parameterless_type(x::Type) = __parameterless_type(x) import Requires import Reexport +function numback_hesvec end +function numback_hesvec! end +function autoback_hesvec end +function autoback_hesvec! end +function auto_vecjac end +function auto_vecjac! end +function ZygoteVecJac end +function ZygoteHesVec end + @static if !isdefined(Base, :get_extension) function __init__() Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin @@ -80,4 +89,10 @@ import Reexport end end +export + numback_hesvec, numback_hesvec!, + autoback_hesvec, autoback_hesvec!, + auto_vecjac, auto_vecjac!, + ZygoteVecJac, ZygoteHesVec + end # module From 5cc3111a382dc227728bcd14a3dd4d9e3a57e54d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 12 Mar 2023 08:36:01 -0400 Subject: [PATCH 10/11] Update Project.toml --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 958fc37d..d227338a 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" @@ -28,7 +29,7 @@ SparseDiffToolsZygote = "Zygote" [compat] Adapt = "3.0" -ArrayInterface = "6, 7" +ArrayInterface = "7" Compat = "4" DataStructures = "0.18" FiniteDiff = "2.8.1" @@ -37,6 +38,7 @@ Graphs = "1" Requires = "1" SciMLOperators = "0.1.19, 0.2" StaticArrays = "1" +StaticArrayInterface = "1.3" Tricks = "0.1.6" VertexSafeGraphs = "0.2" Zygote = "0.6" From 21a9a1efa1b1d9c33ae23371cc27996c553477a6 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 12 Mar 2023 10:50:28 -0400 Subject: [PATCH 11/11] overload functions from main --- ext/SparseDiffToolsZygote.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index a157b673..377820b2 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -4,7 +4,7 @@ if isdefined(Base, :get_extension) import Zygote using LinearAlgebra using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac - using ForwardDiff: ForwardDiff, Dual + using ForwardDiff: ForwardDiff, Dual, partials using SciMLOperators: FunctionOperator using Tricks: static_hasmethod else @@ -18,7 +18,7 @@ end ### Jac, Hes products -function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) +function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end @@ -32,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) @. dy = (cache1 - cache2) / (2ϵ) end -function numback_hesvec(f, x, v) +function SparseDiffTools.numback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f, x)) T = eltype(x) # Should it be min? max? mean? @@ -44,7 +44,7 @@ function numback_hesvec(f, x, v) (gxp - gxm) / (2ϵ) end -function autoback_hesvec!(dy, f, x, v, +function SparseDiffTools.autoback_hesvec!(dy, f, x, v, cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, @@ -62,7 +62,7 @@ function autoback_hesvec!(dy, f, x, v, dy .= partials.(cache2, 1) end -function autoback_hesvec(f, x, v) +function SparseDiffTools.autoback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f, x)) y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) @@ -71,7 +71,7 @@ end # Operator Forms -function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) +function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) if autodiff cache1 = Dual{ @@ -85,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = cache = (cache1, cache2,) - vecprod = autodiff ? autoback_hesvec : numback_hesvec - vecprod! = autodiff ? autoback_hesvec! : numback_hesvec! + vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec + vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec! outofplace = static_hasmethod(f, typeof((u,))) isinplace = static_hasmethod(f, typeof((u,))) @@ -105,16 +105,18 @@ end ## VecJac products -function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) +function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing) !hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false") - du .= reshape(auto_vecjac(f, x, v), size(du)) + du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du)) end -function auto_vecjac(f, x, v) +function SparseDiffTools.auto_vecjac(f, x, v) vv, back = Zygote.pullback(f, x) return vec(back(reshape(v, size(vv)))[1]) end -ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...) +function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...) + VecJac(args...; autodiff = autodiff, kwargs...) +end end # module