Skip to content

Commit 108df97

Browse files
Merge pull request #197 from JuliaDiff/tags
Use cache tags in jacvec hesvec operations
2 parents 39ff350 + 7be687b commit 108df97

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
4-
version = "1.26.1"
4+
version = "1.26.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/differentiation/jaches_products.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
struct DeivVecTag end
22

3+
get_tag(::Array{Dual{T,V,N}}) where {T,V,N} = T
4+
get_tag(::Dual{T,V,N}) where {T,V,N} = T
5+
36
# J(f(x))*v
47
function auto_jacvec!(
58
dy,
@@ -9,7 +12,7 @@ function auto_jacvec!(
912
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
1013
cache2 = similar(cache1),
1114
)
12-
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
15+
cache1 .= Dual{get_tag(cache1),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
1316
f(cache2, cache1)
1417
vecdy = _vec(dy)
1518
vecdy .= partials.(_vec(cache2), 1)
@@ -135,7 +138,7 @@ function autonum_hesvec!(
135138
)
136139
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
137140
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
138-
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
141+
cache1 .= Dual{get_tag(cache1),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
139142
g(cache2, cache1)
140143
dy .= partials.(cache2, 1)
141144
end
@@ -175,7 +178,7 @@ function auto_hesvecgrad!(
175178
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
176179
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
177180
)
178-
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag(),eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
181+
cache2 .= Dual{get_tag(cache2),eltype(x),1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
179182
g(cache3, cache2)
180183
dy .= partials.(cache3, 1)
181184
end

0 commit comments

Comments
 (0)