Skip to content

Commit aa798b2

Browse files
authored
add eltype for AbstractConfig (#222)
1 parent edad35c commit aa798b2

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

src/config.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ end
4040

4141
Base.copy(cfg::AbstractConfig) = deepcopy(cfg)
4242

43+
Base.eltype(cfg::AbstractConfig) = eltype(typeof(cfg))
44+
4345
@inline chunksize(::AbstractConfig{T,N}) where {T,N} = N
4446

4547
##################
@@ -60,6 +62,8 @@ function GradientConfig(::F,
6062
return GradientConfig{T,V,N,typeof(duals)}(seeds, duals)
6163
end
6264

65+
Base.eltype(::Type{GradientConfig{T,V,N,D}}) where {T,V,N,D} = Dual{T,V,N}
66+
6367
##################
6468
# JacobianConfig #
6569
##################
@@ -90,13 +94,15 @@ function JacobianConfig(::F,
9094
return JacobianConfig{T,X,N,typeof(duals)}(seeds, duals)
9195
end
9296

97+
Base.eltype(::Type{JacobianConfig{T,V,N,D}}) where {T,V,N,D} = Dual{T,V,N}
98+
9399
#################
94100
# HessianConfig #
95101
#################
96102

97103
struct HessianConfig{T,V,N,D,H,DJ} <: AbstractConfig{T,N}
98104
jacobian_config::JacobianConfig{Tag{Void,H},V,N,DJ}
99-
gradient_config::GradientConfig{T,Dual{Tag{Void,H},V,N},D}
105+
gradient_config::GradientConfig{T,Dual{Tag{Void,H},V,N},N,D}
100106
end
101107

102108
function HessianConfig(f::F,
@@ -117,3 +123,5 @@ function HessianConfig(f::F,
117123
gradient_config = GradientConfig(f, jacobian_config.duals[2], chunk, tag)
118124
return HessianConfig(jacobian_config, gradient_config)
119125
end
126+
127+
Base.eltype(::Type{HessianConfig{T,V,N,D,H,DJ}}) where {T,V,N,D,H,DJ} = Dual{T,Dual{Tag{Void,H},V,N},N}

test/GradientTest.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import Calculus
44

55
using Base.Test
66
using ForwardDiff
7+
using ForwardDiff: Dual, Tag
78
using StaticArrays
89

910
include(joinpath(dirname(@__FILE__), "utils.jl"))
@@ -21,6 +22,8 @@ for c in (1, 2, 3), tag in (nothing, f)
2122
println(" ...running hardcoded test with chunk size = $c and tag = $tag")
2223
cfg = ForwardDiff.GradientConfig(tag, x, ForwardDiff.Chunk{c}())
2324

25+
@test eltype(cfg) == Dual{typeof(Tag(typeof(tag), eltype(x))), eltype(x), c}
26+
2427
@test isapprox(g, ForwardDiff.gradient(f, x, cfg))
2528
@test isapprox(g, ForwardDiff.gradient(f, x))
2629

test/HessianTest.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import Calculus
44

55
using Base.Test
66
using ForwardDiff
7+
using ForwardDiff: Dual, Tag
78
using StaticArrays
89

910
include(joinpath(dirname(@__FILE__), "utils.jl"))
@@ -25,6 +26,10 @@ for c in (1, 2, 3), tag in (nothing, f)
2526
cfg = ForwardDiff.HessianConfig(tag, x, ForwardDiff.Chunk{c}())
2627
resultcfg = ForwardDiff.HessianConfig(tag, DiffBase.HessianResult(x), x, ForwardDiff.Chunk{c}())
2728

29+
D = Dual{typeof(Tag(Void, eltype(x))), eltype(x), c}
30+
@test eltype(cfg) == Dual{typeof(Tag(typeof(tag), Dual{Void,eltype(x),0})), D, c}
31+
@test eltype(resultcfg) == eltype(cfg)
32+
2833
@test isapprox(h, ForwardDiff.hessian(f, x))
2934
@test isapprox(h, ForwardDiff.hessian(f, x, cfg))
3035

test/JacobianTest.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Calculus
44

55
using Base.Test
66
using ForwardDiff
7-
using ForwardDiff: JacobianConfig
7+
using ForwardDiff: Dual, Tag, JacobianConfig
88
using StaticArrays
99

1010
include(joinpath(dirname(@__FILE__), "utils.jl"))
@@ -34,6 +34,9 @@ for c in (1, 2, 3), tags in ((nothing, nothing), (f, f!))
3434
cfg = JacobianConfig(tags[1], x, ForwardDiff.Chunk{c}())
3535
ycfg = JacobianConfig(tags[2], zeros(4), x, ForwardDiff.Chunk{c}())
3636

37+
@test eltype(cfg) == Dual{typeof(Tag(typeof(tags[1]), eltype(x))), eltype(x), c}
38+
@test eltype(ycfg) == Dual{typeof(Tag(typeof(tags[2]), eltype(x))), eltype(x), c}
39+
3740
# testing f(x)
3841
@test isapprox(j, ForwardDiff.jacobian(f, x, cfg))
3942
@test isapprox(j, ForwardDiff.jacobian(f, x))

0 commit comments

Comments
 (0)