diff --git a/Project.toml b/Project.toml index 79ac2698..a043d2cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "1.19.1" +version = "1.19.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index b35b3778..a2b5d48a 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -6,17 +6,22 @@ function auto_jacvec!( f, x, v, - cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), cache2 = similar(cache1), ) - cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x))) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) f(cache2, cache1) - dy .= partials.(cache2, 1) + vecdy = _vec(dy) + vecdy .= partials.(_vec(cache2), 1) end +_vec(v) = vec(v) +_vec(v::AbstractVector) = v + function auto_jacvec(f, x, v) vv = reshape(v, axes(x)) - vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1)) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(vv))) + vec(partials.(vec(f(y)), 1)) end function num_jacvec!( @@ -122,12 +127,12 @@ function autonum_hesvec!( f, x, v, - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v), - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v), + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), ) cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central}) g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache) - cache1 .= Dual{DeivVecTag}.(x, v) + cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache2, cache1) dy .= partials.(cache2, 1) end @@ -164,16 +169,17 @@ function auto_hesvecgrad!( g, x, v, - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v), - cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v), + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), ) - cache2 .= Dual{DeivVecTag}.(x, v) + cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache3, cache2) dy .= partials.(cache3, 1) end function auto_hesvecgrad(g, x, v) - partials.(g(Dual{DeivVecTag}.(x, v)), 1) + y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + partials.(g(y), 1) end ### Operator Forms @@ -188,8 +194,8 @@ end function JacVec(f, x::AbstractArray; autodiff = true) if autodiff - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x) - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x) + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) else cache1 = similar(x) cache2 = similar(x) @@ -197,6 +203,7 @@ function JacVec(f, x::AbstractArray; autodiff = true) JacVec(f, cache1, cache2, x, autodiff) end +Base.eltype(L::JacVec) = eltype(L.x) Base.size(L::JacVec) = (length(L.cache1), length(L.cache1)) Base.size(L::JacVec, i::Int) = length(L.cache1) Base.:*(L::JacVec, v::AbstractVector) = @@ -256,8 +263,8 @@ end function HesVecGrad(g, x::AbstractArray; autodiff = false) if autodiff - cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x) - cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x) + cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x))) else cache1 = similar(x) cache2 = similar(x) diff --git a/src/differentiation/jaches_products_zygote.jl b/src/differentiation/jaches_products_zygote.jl index fa672fa9..75a7ecb4 100644 --- a/src/differentiation/jaches_products_zygote.jl +++ b/src/differentiation/jaches_products_zygote.jl @@ -24,17 +24,19 @@ function numback_hesvec(f, x, v) (gxp - gxm)/(2ϵ) end -function autoback_hesvec!(dy, f, x, v, cache2 = ForwardDiff.Dual{Nothing}.(x, v), - cache3 = ForwardDiff.Dual{Nothing}.(x, v)) +function autoback_hesvec!(dy, f, x, v, + cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))) g = let f=f (dx, x) -> dx .= first(Zygote.gradient(f,x)) end - cache2 .= Dual{Nothing}.(x, v) + cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) g(cache3,cache2) dy .= partials.(cache3, 1) end function autoback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f,x)) - ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1) + y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + ForwardDiff.partials.(g(y), 1) end diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index c00db556..cbeb2f54 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -18,8 +18,8 @@ function h(dy,x) FiniteDiff.finite_difference_gradient!(dy,g,x) end -cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) -cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) +cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) +cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) @test num_jacvec!(dy, f, x, v) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @test num_jacvec!(dy, f, x, v, similar(v), similar(v)) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @test num_jacvec(f, x, v) ≈ ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6 @@ -44,8 +44,8 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v) @test numback_hesvec!(dy, g, x, v, similar(v), similar(v)) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test numback_hesvec(g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 -cache3 = ForwardDiff.Dual{Nothing}.(x, v) -cache4 = ForwardDiff.Dual{Nothing}.(x, v) +cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) +cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v))) @test autoback_hesvec!(dy, g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test autoback_hesvec!(dy, g, x, v, cache3, cache4) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8 @test autoback_hesvec(g, x, v) ≈ ForwardDiff.hessian(g,x)*v rtol=1e-8