Skip to content

Commit 7c8723e

Browse files
fix all Dual definitions
1 parent 9690d71 commit 7c8723e

File tree

3 files changed

+27
-23
lines changed

3 files changed

+27
-23
lines changed

src/differentiation/jaches_products.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@ function auto_jacvec!(
66
f,
77
x,
88
v,
9-
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x)))),
9+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
1010
cache2 = similar(cache1),
1111
)
12-
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(reshape(v, size(x))))
12+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
1313
f(cache2, cache1)
1414
vecdy = _vec(dy)
15-
vecdy .= partials.(vec(cache2), 1)
15+
vecdy .= partials.(_vec(cache2), 1)
1616
end
1717

1818
_vec(v) = vec(v)
1919
_vec(v::AbstractVector) = v
2020

2121
function auto_jacvec(f, x, v)
2222
vv = reshape(v, axes(x))
23-
vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(vv)))), 1))
23+
y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(vv)))
24+
vec(partials.(vec(f(y)), 1))
2425
end
2526

2627
function num_jacvec!(
@@ -126,12 +127,12 @@ function autonum_hesvec!(
126127
f,
127128
x,
128129
v,
129-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
130-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
130+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
131+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
131132
)
132133
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
133134
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
134-
cache1 .= Dual{DeivVecTag}.(x, v)
135+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
135136
g(cache2, cache1)
136137
dy .= partials.(cache2, 1)
137138
end
@@ -168,16 +169,17 @@ function auto_hesvecgrad!(
168169
g,
169170
x,
170171
v,
171-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
172-
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v),
172+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
173+
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
173174
)
174-
cache2 .= Dual{DeivVecTag}.(x, v)
175+
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
175176
g(cache3, cache2)
176177
dy .= partials.(cache3, 1)
177178
end
178179

179180
function auto_hesvecgrad(g, x, v)
180-
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
181+
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
182+
partials.(g(y), 1)
181183
end
182184

183185
### Operator Forms
@@ -192,8 +194,8 @@ end
192194

193195
function JacVec(f, x::AbstractArray; autodiff = true)
194196
if autodiff
195-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
196-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
197+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
198+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
197199
else
198200
cache1 = similar(x)
199201
cache2 = similar(x)
@@ -261,8 +263,8 @@ end
261263

262264
function HesVecGrad(g, x::AbstractArray; autodiff = false)
263265
if autodiff
264-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
265-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
266+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
267+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
266268
else
267269
cache1 = similar(x)
268270
cache2 = similar(x)

src/differentiation/jaches_products_zygote.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,19 @@ function numback_hesvec(f, x, v)
2424
(gxp - gxm)/(2ϵ)
2525
end
2626

27-
function autoback_hesvec!(dy, f, x, v, cache2 = ForwardDiff.Dual{Nothing}.(x, v),
28-
cache3 = ForwardDiff.Dual{Nothing}.(x, v))
27+
function autoback_hesvec!(dy, f, x, v,
28+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
29+
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
2930
g = let f=f
3031
(dx, x) -> dx .= first(Zygote.gradient(f,x))
3132
end
32-
cache2 .= Dual{Nothing}.(x, v)
33+
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
3334
g(cache3,cache2)
3435
dy .= partials.(cache3, 1)
3536
end
3637

3738
function autoback_hesvec(f, x, v)
3839
g = x -> first(Zygote.gradient(f,x))
39-
ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1)
40+
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
41+
ForwardDiff.partials.(g(y), 1)
4042
end

test/test_jaches_products.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ function h(dy,x)
1818
FiniteDiff.finite_difference_gradient!(dy,g,x)
1919
end
2020

21-
cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
22-
cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
21+
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
22+
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
2323
@test num_jacvec!(dy, f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
2424
@test num_jacvec!(dy, f, x, v, similar(v), similar(v)) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
2525
@test num_jacvec(f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
@@ -44,8 +44,8 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
4444
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-8
4545
@test numback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
4646

47-
cache3 = ForwardDiff.Dual{Nothing}.(x, v)
48-
cache4 = ForwardDiff.Dual{Nothing}.(x, v)
47+
cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
48+
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
4949
@test autoback_hesvec!(dy, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
5050
@test autoback_hesvec!(dy, g, x, v, cache3, cache4) ForwardDiff.hessian(g,x)*v rtol=1e-8
5151
@test autoback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

0 commit comments

Comments
 (0)