From c0846a740a87d0b0065de3bdd021d0ca495e867c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 28 Dec 2021 22:35:51 -0500 Subject: [PATCH 1/4] proper duals for JacVec --- Project.toml | 2 +- src/differentiation/jaches_products.jl | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 79ac2698..a043d2cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "1.19.1" +version = "1.19.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index b35b3778..149fdfa3 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -6,17 +6,21 @@ function auto_jacvec!( f, x, v, - cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x)))), cache2 = similar(cache1), ) - cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x))) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x)))) f(cache2, cache1) - dy .= partials.(cache2, 1) + vecdy = _vec(dy) + vecdy .= partials.(vec(cache2), 1) end +_vec(v) = vec(v) +_vec(v::AbstractVector) = v + function auto_jacvec(f, x, v) vv = reshape(v, axes(x)) - vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1)) + vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, vv))), 1)) end function num_jacvec!( @@ -197,6 +201,7 @@ function JacVec(f, x::AbstractArray; autodiff = true) JacVec(f, cache1, cache2, x, autodiff) end +Base.eltype(L::JacVec) = eltype(L.x) Base.size(L::JacVec) = (length(L.cache1), length(L.cache1)) Base.size(L::JacVec, i::Int) = length(L.cache1) Base.:*(L::JacVec, v::AbstractVector) = From a4ed331343dff776903bb65b995fed5bd47ebeb1 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 28 Dec 2021 22:49:42 -0500 Subject: [PATCH 2/4] add missing partials --- src/differentiation/jaches_products.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 149fdfa3..d9bfa388 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -20,7 +20,7 @@ _vec(v::AbstractVector) = v function auto_jacvec(f, x, v) vv = reshape(v, axes(x)) - vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, vv))), 1)) + vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(vv)))), 1)) end function num_jacvec!( From 9690d71ba2853de610d08725eb358ad6a735114d Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 28 Dec 2021 22:56:51 -0500 Subject: [PATCH 3/4] fix broadcast --- src/differentiation/jaches_products.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index d9bfa388..409cdb14 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -6,10 +6,10 @@ function auto_jacvec!( f, x, v, - cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x)))), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x)))), cache2 = similar(cache1), ) - cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x)))) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x)))) f(cache2, cache1) vecdy = _vec(dy) vecdy .= partials.(vec(cache2), 1) @@ -20,7 +20,7 @@ _vec(v::AbstractVector) = v function auto_jacvec(f, x, v) vv = reshape(v, axes(x)) - vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(vv)))), 1)) + vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(vv)))), 1)) end function num_jacvec!( From 7c8723e5557d2401c2d5bbc648bd8de66f9ffa87 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 29 Dec 2021 05:44:40 -0500 Subject: [PATCH 4/4] fix all Dual definitions --- src/differentiation/jaches_products.jl | 32 ++++++++++--------- src/differentiation/jaches_products_zygote.jl | 10 +++--- test/test_jaches_products.jl | 8 ++--- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 409cdb14..a2b5d48a 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -6,13 +6,13 @@ function auto_jacvec!( f, x, v, - cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x)))), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), cache2 = similar(cache1), ) - cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x)))) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) f(cache2, cache1) vecdy = _vec(dy) - vecdy .= partials.(vec(cache2), 1) + vecdy .= partials.(_vec(cache2), 1) end _vec(v) = vec(v) @@ -20,7 +20,8 @@ _vec(v::AbstractVector) = v function auto_jacvec(f, x, v) vv = reshape(v, axes(x)) - vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(vv)))), 1)) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(vv))) + vec(partials.(vec(f(y)), 1)) end function num_jacvec!( @@ -126,12 +127,12 @@ function autonum_hesvec!( f, x, v, - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v), - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), ) cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central}) g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache) - cache1 .= Dual{DeivVecTag}.(x, v) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache2, cache1) dy .= partials.(cache2, 1) end @@ -168,16 +169,17 @@ function auto_hesvecgrad!( g, x, v, - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v), - cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v), + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), ) - cache2 .= Dual{DeivVecTag}.(x, v) + cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache3, cache2) dy .= partials.(cache3, 1) end function auto_hesvecgrad(g, x, v) - partials.(g(Dual{DeivVecTag}.(x, v)), 1) + y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + partials.(g(y), 1) end ### Operator Forms @@ -192,8 +194,8 @@ end function JacVec(f, x::AbstractArray; autodiff = true) if autodiff - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x) - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x) + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) else cache1 = similar(x) cache2 = similar(x) @@ -261,8 +263,8 @@ end function HesVecGrad(g, x::AbstractArray; autodiff = false) if autodiff - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x) - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x) + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) else cache1 = similar(x) cache2 = similar(x) diff --git a/src/differentiation/jaches_products_zygote.jl b/src/differentiation/jaches_products_zygote.jl index fa672fa9..75a7ecb4 100644 --- a/src/differentiation/jaches_products_zygote.jl +++ b/src/differentiation/jaches_products_zygote.jl @@ -24,17 +24,19 @@ function numback_hesvec(f, x, v) (gxp - gxm)/(2ϵ) end -function autoback_hesvec!(dy, f, x, v, cache2 = ForwardDiff.Dual{Nothing}.(x, v), - cache3 = ForwardDiff.Dual{Nothing}.(x, v)) +function autoback_hesvec!(dy, f, x, v, + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))) g = let f=f (dx, x) -> dx .= first(Zygote.gradient(f,x)) end - cache2 .= Dual{Nothing}.(x, v) + cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache3,cache2) dy .= partials.(cache3, 1) end function autoback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f,x)) - ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1) + y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + ForwardDiff.partials.(g(y), 1) end diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index c00db556..cbeb2f54 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -18,8 +18,8 @@ function h(dy,x) FiniteDiff.finite_difference_gradient!(dy,g,x) end -cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) -cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) +cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) +cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) @test num_jacvec!(dy, f, x, v) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @test num_jacvec!(dy, f, x, v, similar(v), similar(v)) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @test num_jacvec(f, x, v) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @@ -44,8 +44,8 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) @test numback_hesvec!(dy, g, x, v, similar(v), similar(v)) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test numback_hesvec(g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 -cache3 = ForwardDiff.Dual{Nothing}.(x, v) -cache4 = ForwardDiff.Dual{Nothing}.(x, v) +cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) +cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) @test autoback_hesvec!(dy, g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test autoback_hesvec!(dy, g, x, v, cache3, cache4) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test autoback_hesvec(g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8