diff --git a/.gitignore b/.gitignore index 19091028..a7b6f14b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ .DS_Store +.*.swp +.*.swo Manifest.toml /dev/ docs/build/ -docs/site/ \ No newline at end of file +docs/site/ diff --git a/Project.toml b/Project.toml index ee32314b..090b3555 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "1.31.0" +version = "2.00.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -13,8 +13,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" [compat] @@ -25,8 +27,10 @@ DataStructures = "0.18" FiniteDiff = "2.8.1" ForwardDiff = "0.10" Graphs = "1" -Requires = "1.0" +Requires = "1" +SciMLOperators = "0.1.19" StaticArrays = "1" +Tricks = "0.1.6" VertexSafeGraphs = "0.2" julia = "1.6" diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 2a9b64dc..621f7dbf 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -19,6 +19,12 @@ using DataStructures: DisjointSets, find_root!, union! using ArrayInterface: matrix_colors +using SciMLOperators +import SciMLOperators: update_coefficients, update_coefficients! 
+using Tricks: static_hasmethod + +abstract type AbstractAutoDiffVecProd end + export contract_color, greedy_d1, greedy_star1_coloring, @@ -42,7 +48,8 @@ export contract_color, autonum_hesvec, autonum_hesvec!, num_hesvecgrad, num_hesvecgrad!, auto_hesvecgrad, auto_hesvecgrad!, - JacVec, HesVec, HesVecGrad, + JacVec, HesVec, HesVecGrad, VecJac, + update_coefficients, update_coefficients!, value! include("coloring/high_level.jl") @@ -64,8 +71,10 @@ parameterless_type(x::Type) = __parameterless_type(x) function __init__() @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!, - auto_vecjac, auto_vecjac! + export numback_hesvec, numback_hesvec!, + autoback_hesvec, autoback_hesvec!, + auto_vecjac, auto_vecjac!, + ZygoteVecJac, ZygoteHesVec include("differentiation/vecjac_products_zygote.jl") include("differentiation/jaches_products_zygote.jl") diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 56fab615..8717e857 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -198,112 +198,126 @@ end ### Operator Forms -struct JacVec{F, T1, T2, xType} +struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd f::F - cache1::T1 - cache2::T2 - x::xType - autodiff::Bool + u::U + cache::C + vecprod::V + vecprod!::V! 
end -function JacVec(f, x::AbstractArray, tag = DeivVecTag(); autodiff = true) - if autodiff - cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(tuple.(x))) - cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(tuple.(x))) - else - cache1 = similar(x) - cache2 = similar(x) - end - JacVec(f, cache1, cache2, x, autodiff) +function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t) # returns a new operator state at `u`; arguments follow the field order f, u, cache, vecprod, vecprod! + FwdModeAutoDiffVecProd(L.f, u, L.cache, L.vecprod, L.vecprod!) end -Base.eltype(L::JacVec) = eltype(L.x) -Base.size(L::JacVec) = (length(L.cache1), length(L.cache1)) -Base.size(L::JacVec, i::Int) = length(L.cache1) -function Base.:*(L::JacVec, v::AbstractVector) - L.autodiff ? auto_jacvec(_x -> L.f(_x), L.x, v) : - num_jacvec(_x -> L.f(_x), L.x, v) +function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t) # in-place update: overwrites the stored `u` and returns `L` + copy!(L.u, u) + L end -function LinearAlgebra.mul!(dy::AbstractVector, L::JacVec, v::AbstractVector) - if L.autodiff - auto_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2) - else - num_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2) - end +function (L::FwdModeAutoDiffVecProd)(v, p, t) # out-of-place product at the stored `u`; `p` and `t` are not used here + L.vecprod(L.f, L.u, v) end -struct HesVec{F, T1, T2, xType} - f::F - cache1::T1 - cache2::T2 - cache3::T2 - x::xType - autodiff::Bool +function (L::FwdModeAutoDiffVecProd)(dv, v, p, t) # in-place product into `dv`; the cache tuple is splatted as trailing arguments + L.vecprod!(dv, L.f, L.u, v, L.cache...) 
end -function HesVec(f, x::AbstractArray; autodiff = true) +function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) + if autodiff - cache1 = ForwardDiff.GradientConfig(f, x) - cache2 = similar(x) - cache3 = similar(x) + cache1 = Dual{ + typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1 + }.(u, ForwardDiff.Partials.(tuple.(u))) + + cache2 = copy(cache1) else - cache1 = similar(x) - cache2 = similar(x) - cache3 = similar(x) + cache1 = similar(u) + cache2 = similar(u) end - HesVec(f, cache1, cache2, cache3, x, autodiff) -end -Base.size(L::HesVec) = (length(L.cache2), length(L.cache2)) -Base.size(L::HesVec, i::Int) = length(L.cache2) -function Base.:*(L::HesVec, v::AbstractVector) - L.autodiff ? numauto_hesvec(L.f, L.x, v) : num_hesvec(L.f, L.x, v) -end + cache = (cache1, cache2,) -function LinearAlgebra.mul!(dy::AbstractVector, L::HesVec, v::AbstractVector) - if L.autodiff - numauto_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3) - else - num_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3) + vecprod = autodiff ? auto_jacvec : num_jacvec + vecprod! = autodiff ? auto_jacvec! : num_jacvec! + + outofplace = static_hasmethod(f, typeof((u,))) + isinplace = static_hasmethod(f, typeof((u, u,))) + + if !(isinplace) & !(outofplace) + error("$f must have signature f(u), or f(du, u).") end -end -struct HesVecGrad{G, T1, T2, uType} - g::G - cache1::T1 - cache2::T2 - x::uType - autodiff::Bool + L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!) 
+ + FunctionOperator(L, u, u; + isinplace = isinplace, outofplace = outofplace, + p = p, t = t, islinear = true, + ) end -function HesVecGrad(g, x::AbstractArray, tag = DeivVecTag(); autodiff = false) +function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) + if autodiff - cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(tuple.(x))) - cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(tuple.(x))) + cache1 = ForwardDiff.GradientConfig(f, u) + cache2 = similar(u) + cache3 = similar(u) else - cache1 = similar(x) - cache2 = similar(x) + cache1 = similar(u) + cache2 = similar(u) + cache3 = similar(u) end - HesVecGrad(g, cache1, cache2, x, autodiff) -end -Base.size(L::HesVecGrad) = (length(L.cache2), length(L.cache2)) -Base.size(L::HesVecGrad, i::Int) = length(L.cache2) -function Base.:*(L::HesVecGrad, v::AbstractVector) - L.autodiff ? auto_hesvecgrad(L.g, L.x, v) : num_hesvecgrad(L.g, L.x, v) + cache = (cache1, cache2, cache3,) + + vecprod = autodiff ? numauto_hesvec : num_hesvec + vecprod! = autodiff ? numauto_hesvec! : num_hesvec! + + outofplace = static_hasmethod(f, typeof((u,))) + isinplace = static_hasmethod(f, typeof((u,))) + + if !(isinplace) & !(outofplace) + error("$f must have signature f(u).") + end + + L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!) 
+ + FunctionOperator(L, u, u; + isinplace = isinplace, outofplace = outofplace, + p = p, t = t, islinear = true, + ) end -function LinearAlgebra.mul!(dy::AbstractVector, - L::HesVecGrad, - v::AbstractVector) - if L.autodiff - auto_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2) +function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) + + if autodiff + cache1 = Dual{ + typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1 + }.(u, ForwardDiff.Partials.(tuple.(u))) + + cache2 = copy(cache1) else - num_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2) + cache1 = similar(u) + cache2 = similar(u) + end + + cache = (cache1, cache2,) + + vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad + vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad! + + outofplace = static_hasmethod(f, typeof((u,))) + isinplace = static_hasmethod(f, typeof((u, u,))) + + if !(isinplace) & !(outofplace) + error("$f must have signature f(u), or f(du, u).") end + + L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!) 
+ + FunctionOperator(L, u, u; + isinplace = isinplace, outofplace = outofplace, + p = p, t = t, islinear = true, + ) end +# diff --git a/src/differentiation/jaches_products_zygote.jl b/src/differentiation/jaches_products_zygote.jl index e97dbc95..3187d3ff 100644 --- a/src/differentiation/jaches_products_zygote.jl +++ b/src/differentiation/jaches_products_zygote.jl @@ -25,21 +25,21 @@ function numback_hesvec(f, x, v) end function autoback_hesvec!(dy, f, x, v, - cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), - cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))) g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end - cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) - g(cache3, cache2) - dy .= partials.(cache3, 1) + g(cache2, cache1) + dy .= partials.(cache2, 1) end function autoback_hesvec(f, x, v) @@ -48,3 +48,38 @@ function autoback_hesvec(f, x, v) }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) ForwardDiff.partials.(g(y), 1) end + +### Operator Forms + +function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true) + + if autodiff + cache1 = Dual{ + typeof(ForwardDiff.Tag(DeivVecTag, eltype(u))), eltype(u), 1 # same tag expression as autoback_hesvec!, which refills cache1 in place + }.(u, ForwardDiff.Partials.(tuple.(u))) + cache2 = copy(cache1) # autoback_hesvec! stores Dual-valued gradients in cache2 and reads their partials + else + cache1 = similar(u) + cache2 = similar(u) + end + + cache = (cache1, cache2,) + + vecprod = autodiff ? autoback_hesvec : numback_hesvec + vecprod! = autodiff ? autoback_hesvec! : numback_hesvec! 
+ + outofplace = static_hasmethod(f, typeof((u,))) + isinplace = static_hasmethod(f, typeof((u,))) + + if !(isinplace) & !(outofplace) + error("$f must have signature f(u).") + end + + L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!) + + FunctionOperator(L, u, u; + isinplace = isinplace, outofplace = outofplace, + p = p, t = t, islinear = true, + ) +end +# diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index d31e4172..2a3c4b89 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -34,3 +34,84 @@ function num_vecjac(f, x, v, f0 = nothing) end return vec(du) end + +### Operator Forms + +struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd + f::F + u::U + cache::C + vecprod::V + vecprod!::V! + + function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = false, + isinplace = false, outofplace = true) + @assert isinplace || outofplace + + new{ + autodiff, + isinplace, + outofplace, + typeof(f), + typeof(u), + typeof(cache), + typeof(vecprod), + typeof(vecprod!), + }( + f, u, cache, vecprod, vecprod!, + ) + end +end + +function update_coefficients(L::RevModeAutoDiffVecProd{ad,iip,oop}, u, p, t) where{ad,iip,oop} # returns a new operator state at `u`, keeping the ad/iip/oop type flags of `L` + RevModeAutoDiffVecProd(L.f, u, L.cache, L.vecprod, L.vecprod!; autodiff = ad, isinplace = iip, outofplace = oop) +end + +function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t) # in-place update: overwrites the stored `u` and returns `L` + copy!(L.u, u) + L +end + +# Interpret the call as (df/du)' * v, evaluated at the stored `u` +function (L::RevModeAutoDiffVecProd)(v, p, t) + L.vecprod(_u -> L.f(_u, p, t), L.u, v) +end + +# prefer non in-place method +function (L::RevModeAutoDiffVecProd{ad,iip,true})(dv, v, p, t) where{ad,iip} + L.vecprod!(dv, _u -> L.f(_u, p, t), L.u, v, L.cache...) +end + +function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad} + L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...) 
+end + +function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = false, + ishermitian = false, opnorm = true) # `ishermitian` and `opnorm` are forwarded to FunctionOperator below + + if autodiff + @assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`." + end + + cache = (similar(u), similar(u),) + + vecprod = autodiff ? auto_vecjac : num_vecjac + vecprod! = autodiff ? auto_vecjac! : num_vecjac! + + outofplace = static_hasmethod(f, typeof((u, p, t))) + isinplace = static_hasmethod(f, typeof((u, u, p, t))) + + if !(isinplace) & !(outofplace) + error("$f must have signature f(u, p, t), or f(du, u, p, t)") + end + + L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff, + isinplace = isinplace, outofplace = outofplace) + + FunctionOperator(L, u, u; + isinplace = isinplace, outofplace = outofplace, + p = p, t = t, islinear = true, + ishermitian = ishermitian, opnorm = opnorm, + ) +end +# diff --git a/src/differentiation/vecjac_products_zygote.jl b/src/differentiation/vecjac_products_zygote.jl index 495502de..f5f22623 100644 --- a/src/differentiation/vecjac_products_zygote.jl +++ b/src/differentiation/vecjac_products_zygote.jl @@ -7,3 +7,8 @@ function auto_vecjac(f, x, v) vv, back = Zygote.pullback(f, x) return vec(back(reshape(v, size(vv)))[1]) end + +#ZygoteVecJac = VecJac +ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...) 
+ +# diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 3255caa1..c96c01b2 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,3 +1,2 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ArrayInterfaceGPUArrays = "6ba088a2-8465-4c0a-af30-387133b534db" \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e3080368..6c44a287 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,7 @@ if GROUP == "All" @time @safetestset "Integration test" begin include("test_integration.jl") end @time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end @time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end + @time @safetestset "Vec Jac Products" begin include("test_vecjac_products.jl") end end if GROUP == "GPU" diff --git a/test/test_ad.jl b/test/test_ad.jl index 749008ca..f79759d7 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -2,10 +2,9 @@ using SparseDiffTools using ForwardDiff: Dual, jacobian, value using SparseArrays, Test using LinearAlgebra -using BlockBandedMatrices, ArrayInterfaceBlockBandedMatrices -using BandedMatrices, ArrayInterfaceBandedMatrices +using BlockBandedMatrices +using BandedMatrices using StaticArrays -using ArrayInterfaceStaticArrays fcalls = 0 function f(dx, x) diff --git a/test/test_gpu_ad.jl b/test/test_gpu_ad.jl index 3f9e4cce..4ae4b303 100644 --- a/test/test_gpu_ad.jl +++ b/test/test_gpu_ad.jl @@ -1,7 +1,6 @@ using SparseDiffTools, CUDA, Test, LinearAlgebra -using ArrayInterfaceCore: allowed_getindex, allowed_setindex! +using ArrayInterface: allowed_getindex, allowed_setindex! 
using SparseArrays -using ArrayInterfaceGPUArrays function f(dx, x) dx[2:(end - 1)] = x[1:(end - 2)] - 2x[2:(end - 1)] + x[3:end] diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index 20918232..1efcf565 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -3,12 +3,13 @@ using LinearAlgebra, Test using Random Random.seed!(123) - -const A = rand(300, 300) +N = 300 +const A = rand(N, N) f(y, x) = mul!(y, A, x) f(x) = A * x -x = rand(300) -v = rand(300) +x = rand(N) +v = rand(N) +a, b = rand(2) dy = similar(x) g(x) = sum(abs2, x) function h(x) @@ -20,8 +21,7 @@ end cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) -cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), - eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) +cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) @test num_jacvec!(dy, f, x, v)≈ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6 @test num_jacvec!(dy, f, x, v, similar(v), similar(v))≈ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6 @@ -65,61 +65,103 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x) @test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 @test auto_hesvecgrad(h, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 +@info "JacVec" + L = JacVec(f, x) @test L * x ≈ auto_jacvec(f, x, x) @test L * v ≈ auto_jacvec(f, x, v) @test mul!(dy, L, v) ≈ auto_jacvec(f, x, v) -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v) ≈ auto_jacvec(f, v, v) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy L = JacVec(f, x, autodiff = false) @test L * x ≈ num_jacvec(f, x, x) @test L * v ≈ 
num_jacvec(f, x, v) -L.x == x @test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6 -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v)≈num_jacvec(f, v, v) rtol=1e-6 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 -### Integration test with IterativeSolvers out = similar(v) gmres!(out, L, v) -x = rand(300) -v = rand(300) +@info "HesVec" + +x = rand(N) +v = rand(N) L = HesVec(g, x, autodiff = false) @test L * x ≈ num_hesvec(g, x, x) @test L * v ≈ num_hesvec(g, x, v) @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 L = HesVec(g, x) @test L * x ≈ numauto_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -### Integration test with IterativeSolvers out = similar(v) gmres!(out, L, v) -x = rand(300) -v = rand(300) +@info "ZygoteHesVec" +using Zygote +x = rand(N) +v = rand(N) + +L = ZygoteHesVec(g, x, autodiff = false) +@test L * x ≈ numback_hesvec(g, x, x) rtol = 1e-2 +@test L * v ≈ numback_hesvec(g, x, v) rtol = 1e-2 +@test mul!(dy, L, v)≈numback_hesvec(g, x, v) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2 +update_coefficients!(L, v, nothing, 0.0) +@test mul!(dy, L, v)≈numback_hesvec(g, v, v) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ 
a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2 + +L = HesVec(g, x) +@test L * x ≈ autoback_hesvec(g, x, x) +@test L * v ≈ autoback_hesvec(g, x, v) +@test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(L, v, nothing, 0.0) +@test mul!(dy, L, v)≈autoback_hesvec(g, v, v) rtol=1e-8 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 + +out = similar(v) +gmres!(out, L, v) + + +@info "HesVecGrad" + +x = rand(N) +v = rand(N) L = HesVecGrad(h, x, autodiff = false) @test L * x ≈ num_hesvec(g, x, x) @test L * v ≈ num_hesvec(g, x, v) @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 L = HesVecGrad(h, x, autodiff = true) @test L * x ≈ autonum_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 -L.x .= v +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(L, v, nothing, 0.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -### Integration test with IterativeSolvers out = similar(v) gmres!(out, L, v) +# diff --git a/test/test_vecjac_products.jl b/test/test_vecjac_products.jl new file mode 100644 index 00000000..07c9c2bd --- /dev/null +++ b/test/test_vecjac_products.jl @@ -0,0 +1,34 @@ +using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers +using LinearAlgebra, Test + +using Random +Random.seed!(123) +N = 300 +const A = rand(N, N) + +x = rand(Float32, N) +v = rand(Float32, N) + +f(du,u,p,t) = mul!(du, A, u) +f(u,p,t) = A * u + +@info "VecJac" + 
+L = VecJac(f, x) +actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v +update_coefficients!(L, v, nothing, 0.0) +@test L * v ≈ actual_vjp +L = VecJac(f, x; autodiff = false) +update_coefficients!(L, v, nothing, 0.0) +@test L * v ≈ actual_vjp + +@info "ZygoteVecJac" + +L = ZygoteVecJac(f, x) +actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v +update_coefficients!(L, v, nothing, 0.0) +@test L * v ≈ actual_vjp +L = ZygoteVecJac(f, x; autodiff = false) +update_coefficients!(L, v, nothing, 0.0) +@test L * v ≈ actual_vjp +#