From 7be687b0e23b63a87676249eb33704af82ccef47 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 23 Aug 2022 09:47:24 -0400 Subject: [PATCH] Use cache tags in jacvec hesvec operations --- Project.toml | 2 +- src/differentiation/jaches_products.jl | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index a41ed7ef..d40435b9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "1.26.1" +version = "1.26.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index d3627d82..21ae0c12 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -1,5 +1,8 @@ struct DeivVecTag end +get_tag(::Array{Dual{T,V,N}}) where {T,V,N} = T +get_tag(::Dual{T,V,N}) where {T,V,N} = T + # J(f(x))*v function auto_jacvec!( dy, @@ -9,7 +12,7 @@ function auto_jacvec!( cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))), cache2 = similar(cache1), ) - cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) + cache1 .= Dual{get_tag(cache1),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) f(cache2, cache1) vecdy = _vec(dy) vecdy .= partials.(_vec(cache2), 1) @@ -135,7 +138,7 @@ function autonum_hesvec!( ) cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central}) g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache) - cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) + cache1 .= Dual{get_tag(cache1),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) g(cache2, cache1) dy .= partials.(cache2, 1) end @@ -175,7 +178,7 @@ function auto_hesvecgrad!( cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))), cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))), ) - cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) + cache2 .= Dual{get_tag(cache2),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) g(cache3, cache2) dy .= partials.(cache3, 1) end