Skip to content

Commit 508a6cc

Browse files
Merge pull request #172 from JuliaDiff/duals
proper duals for JacVec
2 parents 5152045 + 7c8723e commit 508a6cc

File tree

4 files changed

+33
-24
lines changed

4 files changed

+33
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
4-
version = "1.19.1"
4+
version = "1.19.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/differentiation/jaches_products.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,22 @@ function auto_jacvec!(
66
f,
77
x,
88
v,
9-
cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))),
9+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
1010
cache2 = similar(cache1),
1111
)
12-
cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x)))
12+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
1313
f(cache2, cache1)
14-
dy .= partials.(cache2, 1)
14+
vecdy = _vec(dy)
15+
vecdy .= partials.(_vec(cache2), 1)
1516
end
1617

18+
_vec(v) = vec(v)
19+
_vec(v::AbstractVector) = v
20+
1721
function auto_jacvec(f, x, v)
1822
vv = reshape(v, axes(x))
19-
vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1))
23+
y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(vv)))
24+
vec(partials.(vec(f(y)), 1))
2025
end
2126

2227
function num_jacvec!(
@@ -122,12 +127,12 @@ function autonum_hesvec!(
122127
f,
123128
x,
124129
v,
125-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
126-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
130+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
131+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
127132
)
128133
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
129134
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
130-
cache1 .= Dual{DeivVecTag}.(x, v)
135+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
131136
g(cache2, cache1)
132137
dy .= partials.(cache2, 1)
133138
end
@@ -164,16 +169,17 @@ function auto_hesvecgrad!(
164169
g,
165170
x,
166171
v,
167-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
168-
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v),
172+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
173+
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
169174
)
170-
cache2 .= Dual{DeivVecTag}.(x, v)
175+
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
171176
g(cache3, cache2)
172177
dy .= partials.(cache3, 1)
173178
end
174179

175180
function auto_hesvecgrad(g, x, v)
176-
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
181+
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
182+
partials.(g(y), 1)
177183
end
178184

179185
### Operator Forms
@@ -188,15 +194,16 @@ end
188194

189195
function JacVec(f, x::AbstractArray; autodiff = true)
190196
if autodiff
191-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
192-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
197+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
198+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
193199
else
194200
cache1 = similar(x)
195201
cache2 = similar(x)
196202
end
197203
JacVec(f, cache1, cache2, x, autodiff)
198204
end
199205

206+
Base.eltype(L::JacVec) = eltype(L.x)
200207
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
201208
Base.size(L::JacVec, i::Int) = length(L.cache1)
202209
Base.:*(L::JacVec, v::AbstractVector) =
@@ -256,8 +263,8 @@ end
256263

257264
function HesVecGrad(g, x::AbstractArray; autodiff = false)
258265
if autodiff
259-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
260-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
266+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
267+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
261268
else
262269
cache1 = similar(x)
263270
cache2 = similar(x)

src/differentiation/jaches_products_zygote.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,19 @@ function numback_hesvec(f, x, v)
2424
(gxp - gxm)/(2ϵ)
2525
end
2626

27-
function autoback_hesvec!(dy, f, x, v, cache2 = ForwardDiff.Dual{Nothing}.(x, v),
28-
cache3 = ForwardDiff.Dual{Nothing}.(x, v))
27+
function autoback_hesvec!(dy, f, x, v,
28+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
29+
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
2930
g = let f=f
3031
(dx, x) -> dx .= first(Zygote.gradient(f,x))
3132
end
32-
cache2 .= Dual{Nothing}.(x, v)
33+
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
3334
g(cache3,cache2)
3435
dy .= partials.(cache3, 1)
3536
end
3637

3738
function autoback_hesvec(f, x, v)
3839
g = x -> first(Zygote.gradient(f,x))
39-
ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1)
40+
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
41+
ForwardDiff.partials.(g(y), 1)
4042
end

test/test_jaches_products.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ function h(dy,x)
1818
FiniteDiff.finite_difference_gradient!(dy,g,x)
1919
end
2020

21-
cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
22-
cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
21+
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
22+
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
2323
@test num_jacvec!(dy, f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
2424
@test num_jacvec!(dy, f, x, v, similar(v), similar(v)) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
2525
@test num_jacvec(f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
@@ -44,8 +44,8 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
4444
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-8
4545
@test numback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
4646

47-
cache3 = ForwardDiff.Dual{Nothing}.(x, v)
48-
cache4 = ForwardDiff.Dual{Nothing}.(x, v)
47+
cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
48+
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
4949
@test autoback_hesvec!(dy, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
5050
@test autoback_hesvec!(dy, g, x, v, cache3, cache4) ForwardDiff.hessian(g,x)*v rtol=1e-8
5151
@test autoback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

0 commit comments

Comments
 (0)