From 56f70c5b6cd71cc3f5c3cc75f66364b3e932b8ca Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Mon, 27 Jul 2020 07:33:54 -0500 Subject: [PATCH 1/2] Allow value and partials to have distinct types --- Project.toml | 2 +- src/apiutils.jl | 26 ++++++------- src/config.jl | 14 +++---- src/dual.jl | 96 ++++++++++++++++++++++++++++++------------------ test/DualTest.jl | 32 ++++++++-------- 5 files changed, 98 insertions(+), 72 deletions(-) diff --git a/Project.toml b/Project.toml index 4e117b7d..e666aa37 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.12" +version = "0.11.0" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" diff --git a/src/apiutils.jl b/src/apiutils.jl index ae5d78ad..830872a6 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -21,7 +21,7 @@ end @generated function dualize(::Type{T}, x::StaticArray) where T N = length(x) dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...) - V = StaticArrays.similar_type(x, Dual{T,eltype(x),N}) + V = StaticArrays.similar_type(x, Dual{T,eltype(x),N,eltype(x)}) return quote chunk = Chunk{$N}() $(Expr(:meta, :inline)) @@ -53,38 +53,38 @@ end return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...) end -function seed!(duals::AbstractArray{Dual{T,V,N}}, x, - seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N} +function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, + seed::Partials{N,P} = zero(Partials{N,P})) where {T,V,N,P} for i in eachindex(duals) - duals[i] = Dual{T,V,N}(x[i], seed) + duals[i] = Dual{T,V,N,P}(x[i], seed) end return duals end -function seed!(duals::AbstractArray{Dual{T,V,N}}, x, - seeds::NTuple{N,Partials{N,V}}) where {T,V,N} +function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, + seeds::NTuple{N,Partials{N,P}}) where {T,V,N,P} for i in 1:N - duals[i] = Dual{T,V,N}(x[i], seeds[i]) + duals[i] = Dual{T,V,N,P}(x[i], seeds[i]) end return duals end -function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index, - seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N} +function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, index, + seed::Partials{N,P} = zero(Partials{N,P})) where {T,V,N,P} offset = index - 1 for i in 1:N j = i + offset - duals[j] = Dual{T,V,N}(x[j], seed) + duals[j] = Dual{T,V,N,P}(x[j], seed) end return duals end -function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index, - seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N} +function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, index, + seeds::NTuple{N,Partials{N,P}}, chunksize = N) where {T,V,N,P} offset = index - 1 for i in 1:chunksize j = i + offset - duals[j] = Dual{T,V,N}(x[j], seeds[i]) + duals[j] = Dual{T,V,N,P}(x[j], seeds[i]) end return duals end diff --git a/src/config.jl b/src/config.jl index cb23c201..594507d6 100644 --- a/src/config.jl +++ b/src/config.jl @@ -83,7 +83,7 @@ function DerivativeConfig(f::F, y::AbstractArray{Y}, x::X, tag::T = Tag(f, X)) where {F,X<:Real,Y<:Real,T} - duals = similar(y, Dual{T,Y,1}) + duals = similar(y, Dual{T,Y,1,Y}) return DerivativeConfig{T,typeof(duals)}(duals) end @@ -119,7 +119,7 @@ function GradientConfig(f::F, ::Chunk{N} = Chunk(x), ::T = Tag(f, V)) where {F,V,N,T} seeds = construct_seeds(Partials{N,V}) - duals = similar(x, Dual{T,V,N}) + duals = similar(x, Dual{T,V,N,V}) return GradientConfig{T,V,N,typeof(duals)}(seeds, duals) end @@ -156,7 +156,7 @@ function JacobianConfig(f::F, ::Chunk{N} = Chunk(x), ::T = Tag(f, V)) where {F,V,N,T} seeds = construct_seeds(Partials{N,V}) - duals = similar(x, Dual{T,V,N}) + duals = similar(x, Dual{T,V,N,V}) return JacobianConfig{T,V,N,typeof(duals)}(seeds, duals) end @@ -182,8 +182,8 @@ function JacobianConfig(f::F, ::Chunk{N} = Chunk(x), ::T = Tag(f, X)) where {F,Y,X,N,T} seeds = construct_seeds(Partials{N,X}) - yduals = similar(y, Dual{T,Y,N}) - xduals = similar(x, Dual{T,X,N}) + yduals = similar(y, Dual{T,Y,N,Y}) + xduals = similar(x, Dual{T,X,N,X}) duals = (yduals, xduals) return JacobianConfig{T,X,N,typeof(duals)}(seeds, duals) end @@ -197,7 +197,7 @@ Base.eltype(::Type{JacobianConfig{T,V,N,D}}) where {T,V,N,D} = Dual{T,V,N} struct HessianConfig{T,V,N,DG,DJ} <: AbstractConfig{N} jacobian_config::JacobianConfig{T,V,N,DJ} - gradient_config::GradientConfig{T,Dual{T,V,N},N,DG} + gradient_config::GradientConfig{T,Dual{T,V,N,V},N,DG} end """ @@ -254,4 +254,4 @@ end checktag(::HessianConfig{T},f,x) where {T} = checktag(T,f,x) Base.eltype(::Type{HessianConfig{T,V,N,DG,DJ}}) where {T,V,N,DG,DJ} = - Dual{T,Dual{T,V,N},N} + Dual{T,Dual{T,V,N,V},N,Dual{T,V,N,V}} diff --git a/src/dual.jl b/src/dual.jl index bec1e83d..28790b44 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -11,14 +11,16 @@ Dual. By default, only `<:Real` types are allowed. can_dual(::Type{<:Real}) = true can_dual(::Type) = false -struct Dual{T,V,N} <: Real +struct Dual{T,V,N,P} <: Real value::V - partials::Partials{N,V} - function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N} + partials::Partials{N,P} + function Dual{T, V, N, P}(value::V, partials::Partials{N, P}) where {T, V, N, P} can_dual(V) || throw_cannot_dual(V) - new{T, V, N}(value, partials) + can_dual(P) || throw_cannot_dual(P) + new{T, V, N, P}(value, partials) end end +@inline Dual{T,V,N}(value, partials::Partials{N,P}) where {T,V,N,P} = Dual{T,V,N,P}(value, partials) ############## # Exceptions # @@ -52,9 +54,10 @@ tag can be extracted, so it should be used in the _innermost_ function. # Constructors # ################ -@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N}(value, partials) - -@inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A,B} +@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V<:Real} = Dual{T,V,N,V}(value, partials) # ambiguity resolution +@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N,V}(value, partials) +@inline Dual{T}(value::V, partials::Partials{N,P}) where {T,N,V,P} = Dual{T,V,N,P}(value, partials) +@inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A<:Real,B<:Real} C = promote_type(A, B) return Dual{T}(convert(C, value), convert(Partials{N,C}, partials)) end @@ -68,8 +71,9 @@ end @inline Dual(args...) = Dual{Nothing}(args...) # we define these special cases so that the "constructor <--> convert" pun holds for `Dual` -@inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x -@inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x) +@inline Dual{T,V,N,P}(x::Dual{T,V,N,P}) where {T,V,N,P} = x +@inline Dual{T,V,N,P}(x) where {T,V,N,P} = convert(Dual{T,V,N,P}, x) +@inline Dual{T,V,N,P}(x::Number) where {T,V,N,P} = convert(Dual{T,V,N,P}, x) @inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x) @inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x) @@ -109,15 +113,21 @@ end @inline npartials(::Dual{T,V,N}) where {T,V,N} = N -@inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N +@inline npartials(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = N @inline order(::Type{V}) where {V} = 0 -@inline order(::Type{Dual{T,V,N}}) where {T,V,N} = 1 + order(V) +@inline order(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = 1 + order(V) @inline valtype(::V) where {V} = V @inline valtype(::Type{V}) where {V} = V @inline valtype(::Dual{T,V,N}) where {T,V,N} = V @inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V +@inline valtype(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = V + +@inline partialtype(::V) where {V} = V +@inline partialtype(::Type{V}) where {V} = V +@inline partialtype(::Dual{T,V,N,P}) where {T,V,N,P} = P +@inline partialtype(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = P @inline tagtype(::V) where {V} = Nothing @inline tagtype(::Type{V}) where {V} = Nothing @@ -282,10 +292,10 @@ Base.round(d::Dual) = round(value(d)) Base.hash(d::Dual) = hash(value(d)) Base.hash(d::Dual, hsh::UInt) = hash(value(d), hsh) -function Base.read(io::IO, ::Type{Dual{T,V,N}}) where {T,V,N} +function Base.read(io::IO, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} value = read(io, V) - partials = read(io, Partials{N,V}) - return Dual{T,V,N}(value, partials) + partials = read(io, Partials{N,P}) + return Dual{T,V,N,P}(value, partials) end function Base.write(io::IO, d::Dual) @@ -294,18 +304,24 @@ function Base.write(io::IO, d::Dual) end @inline Base.zero(d::Dual) = zero(typeof(d)) -@inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(zero(V), zero(Partials{N,V})) +@inline Base.zero(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(zero(V), zero(Partials{N,P})) +@inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = zero(Dual{T,V,N,V}) @inline Base.one(d::Dual) = one(typeof(d)) -@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(one(V), zero(Partials{N,V})) +@inline Base.one(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(one(V), zero(Partials{N,P})) +@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = one(Dual{T,V,N,V}) @inline Random.rand(rng::AbstractRNG, d::Dual) = rand(rng, value(d)) -@inline Random.rand(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(V), zero(Partials{N,V})) -@inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(rng, V), zero(Partials{N,V})) -@inline Random.randn(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(V), zero(Partials{N,V})) -@inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(rng, V), zero(Partials{N,V})) -@inline Random.randexp(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(V), zero(Partials{N,V})) -@inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(rng, V), zero(Partials{N,V})) +@inline Random.rand(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(rand(V), zero(Partials{N,P})) +@inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(rand(rng, V), zero(Partials{N,P})) +@inline Random.randn(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randn(V), zero(Partials{N,P})) +@inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randn(rng, V), zero(Partials{N,P})) +@inline Random.randexp(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randexp(V), zero(Partials{N,P})) +@inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randexp(rng, V), zero(Partials{N,P})) + +@inline Base.zero(::Type{Partials{N,Dual{T,V,M}}}) where {N,T,V,M} = zero(Partials{N,Dual{T,V,M,V}}) +@inline Base.one(::Type{Partials{N,Dual{T,V,M}}}) where {N,T,V,M} = one(Partials{N,Dual{T,V,M,V}}) + # Predicates # #------------# @@ -331,35 +347,45 @@ end # Promotion/Conversion # ######################## -Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}}, - ::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2} +Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1,P1}}, + ::Type{Dual{T2,V2,N2,P2}}) where {T1,V1,N1,P1,T2,V2,N2,P2} # V1 and V2 might themselves be Dual types if T2 ≺ T1 - Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1} + Dual{T1,promote_type(V1,Dual{T2,V2,N2,P2}),N1,P1} else - Dual{T2,promote_type(V2,Dual{T1,V1,N1}),N2} + Dual{T2,promote_type(V2,Dual{T1,V1,N1,P1}),N2,P2} end end +function Base.promote_rule(::Type{Dual{T,A,N,PA}}, + ::Type{Dual{T,B,N,PB}}) where {T,A,B,PA,PB,N} + return Dual{T,promote_type(A, B),N,promote_type(PA, PB)} +end function Base.promote_rule(::Type{Dual{T,A,N}}, ::Type{Dual{T,B,N}}) where {T,A,B,N} - return Dual{T,promote_type(A, B),N} + return Dual{T,promote_type(A, B),N,promote_type(A, B)} end for R in (Irrational, Real, BigFloat, Bool) if isconcretetype(R) # issue #322 @eval begin - Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N} - Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N} + Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T,promote_type($R, V),N,promote_type($R, P)} + Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N,promote_type($R, V)} end else @eval begin - Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N} - Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N} + Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N,P}}) where {R<:$R,T,V,N,P} = Dual{T,promote_type(R, V),N,promote_type(R, P)} + Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N,promote_type(R, V)} end end end +Base.convert(::Type{Partials{N,Dual{T,V,M}}}, partials::Partials) where {N,T,V,M} = + convert(Partials{N,Dual{T,V,M,V}}, partials) + +Base.convert(::Type{Dual{T,V,N,P}}, d::Dual{T}) where {T,V,N,P} = Dual{T}(convert(V, value(d)), convert(Partials{N,P}, partials(d))) +Base.convert(::Type{Dual{T,V,N,P}}, x) where {T,V,N,P} = Dual{T}(convert(V, x), zero(Partials{N,P})) +Base.convert(::Type{Dual{T,V,N,P}}, x::Number) where {T,V,N,P} = Dual{T}(convert(V, x), zero(Partials{N,P})) Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d))) Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) @@ -621,10 +647,10 @@ function Base.show(io::IO, d::Dual{T,V,N}) where {T,V,N} print(io, ")") end -function Base.typemin(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} - ForwardDiff.Dual{T,V,N}(typemin(V)) +function Base.typemin(::Type{ForwardDiff.Dual{T,V,N,P}}) where {T,V,N,P} + ForwardDiff.Dual{T,V,N,P}(typemin(V)) end -function Base.typemax(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} - ForwardDiff.Dual{T,V,N}(typemax(V)) +function Base.typemax(::Type{ForwardDiff.Dual{T,V,N,P}}) where {T,V,N,P} + ForwardDiff.Dual{T,V,N,P}(typemax(V)) end diff --git a/test/DualTest.jl b/test/DualTest.jl index dec58c1f..26513115 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -57,10 +57,10 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test Dual(PRIMAL, PARTIALS...) === Dual{Nothing}(PRIMAL, PARTIALS...) @test Dual(PRIMAL) === Dual{Nothing}(PRIMAL) - @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS)) === Dual{TestTag(),widen(V),N} - @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS.values)) === Dual{TestTag(),widen(V),N} - @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS...)) === Dual{TestTag(),widen(V),N} - @test typeof(NESTED_FDNUM) == Dual{TestTag(),Dual{TestTag(),V,M},N} + @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS)) === Dual{TestTag(),widen(V),N,widen(V)} + @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS.values)) === Dual{TestTag(),widen(V),N,widen(V)} + @test typeof(Dual{TestTag()}(widen(V)(PRIMAL), PARTIALS...)) === Dual{TestTag(),widen(V),N,widen(V)} + @test typeof(NESTED_FDNUM) == Dual{TestTag(),Dual{TestTag(),V,M,V},N,Dual{TestTag(),V,M,V}} ############# # Accessors # @@ -88,8 +88,8 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test ForwardDiff.valtype(FDNUM) == V @test ForwardDiff.valtype(typeof(FDNUM)) == V - @test ForwardDiff.valtype(NESTED_FDNUM) == Dual{TestTag(),V,M} - @test ForwardDiff.valtype(typeof(NESTED_FDNUM)) == Dual{TestTag(),V,M} + @test ForwardDiff.valtype(NESTED_FDNUM) == Dual{TestTag(),V,M,V} + @test ForwardDiff.valtype(typeof(NESTED_FDNUM)) == Dual{TestTag(),V,M,V} ##################### # Generic Functions # @@ -290,22 +290,22 @@ for N in (0,3), M in (0,4), V in (Int, Float32) WIDE_T = widen(V) - @test promote_type(Dual{TestTag(),V,N}, V) == Dual{TestTag(),V,N} - @test promote_type(Dual{TestTag(),V,N}, WIDE_T) == Dual{TestTag(),WIDE_T,N} - @test promote_type(Dual{TestTag(),WIDE_T,N}, V) == Dual{TestTag(),WIDE_T,N} - @test promote_type(Dual{TestTag(),V,N}, Dual{TestTag(),V,N}) == Dual{TestTag(),V,N} - @test promote_type(Dual{TestTag(),V,N}, Dual{TestTag(),WIDE_T,N}) == Dual{TestTag(),WIDE_T,N} - @test promote_type(Dual{TestTag(),WIDE_T,N}, Dual{TestTag(),Dual{TestTag(),V,M},N}) == Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N} + @test promote_type(Dual{TestTag(),V,N}, V) == Dual{TestTag(),V,N,V} + @test promote_type(Dual{TestTag(),V,N}, WIDE_T) == Dual{TestTag(),WIDE_T,N,WIDE_T} + @test promote_type(Dual{TestTag(),WIDE_T,N}, V) == Dual{TestTag(),WIDE_T,N,WIDE_T} + @test promote_type(Dual{TestTag(),V,N,V}, Dual{TestTag(),V,N,V}) == Dual{TestTag(),V,N,V} + @test promote_type(Dual{TestTag(),V,N}, Dual{TestTag(),WIDE_T,N}) == Dual{TestTag(),WIDE_T,N,WIDE_T} + @test promote_type(Dual{TestTag(),WIDE_T,N}, Dual{TestTag(),Dual{TestTag(),V,M},N}) == Dual{TestTag(),Dual{TestTag(),WIDE_T,M,WIDE_T},N,Dual{TestTag(),WIDE_T,M,WIDE_T}} # issue #322 - @test promote_type(Bool, Dual{TestTag(),V,N}) == Dual{TestTag(),promote_type(Bool, V),N} - @test promote_type(BigFloat, Dual{TestTag(),V,N}) == Dual{TestTag(),promote_type(BigFloat, V),N} + @test promote_type(Bool, Dual{TestTag(),V,N}) == Dual{TestTag(),promote_type(Bool, V),N,promote_type(Bool, V)} + @test promote_type(BigFloat, Dual{TestTag(),V,N}) == Dual{TestTag(),promote_type(BigFloat, V),N,promote_type(BigFloat, V)} WIDE_FDNUM = convert(Dual{TestTag(),WIDE_T,N}, FDNUM) WIDE_NESTED_FDNUM = convert(Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N}, NESTED_FDNUM) - @test typeof(WIDE_FDNUM) === Dual{TestTag(),WIDE_T,N} - @test typeof(WIDE_NESTED_FDNUM) === Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N} + @test typeof(WIDE_FDNUM) === Dual{TestTag(),WIDE_T,N,WIDE_T} + @test typeof(WIDE_NESTED_FDNUM) === Dual{TestTag(),Dual{TestTag(),WIDE_T,M,WIDE_T},N,Dual{TestTag(),WIDE_T,M,WIDE_T}} @test value(WIDE_FDNUM) == PRIMAL @test value(WIDE_NESTED_FDNUM) == PRIMAL From cf8814e34f2e89433518a57d2f40a1fe59c592cb Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Mon, 27 Jul 2020 07:34:16 -0500 Subject: [PATCH 2/2] Use `oneunit` instead of `one` where appropriate --- src/derivative.jl | 8 ++++---- src/partials.jl | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index fa6d98fa..c1b55e4e 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -11,7 +11,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`. """ @inline function derivative(f::F, x::R) where {F,R<:Real} T = typeof(Tag(f, R)) - return extract_derivative(T, f(Dual{T}(x, one(x)))) + return extract_derivative(T, f(Dual{T}(x, oneunit(x)))) end """ @@ -27,7 +27,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba CHK && checktag(T, f!, x) ydual = cfg.duals seed!(ydual, y) - f!(ydual, Dual{T}(x, one(x))) + f!(ydual, Dual{T}(x, oneunit(x))) map!(value, y, ydual) return extract_derivative(T, ydual) end @@ -43,7 +43,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`. @inline function derivative!(result::Union{AbstractArray,DiffResult}, f::F, x::R) where {F,R<:Real} T = typeof(Tag(f, R)) - ydual = f(Dual{T}(x, one(x))) + ydual = f(Dual{T}(x, oneunit(x))) result = extract_value!(T, result, ydual) result = extract_derivative!(T, result, ydual) return result @@ -63,7 +63,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba CHK && checktag(T, f!, x) ydual = cfg.duals seed!(ydual, y) - f!(ydual, Dual{T}(x, one(x))) + f!(ydual, Dual{T}(x, oneunit(x))) result = extract_value!(T, result, y, ydual) result = extract_derivative!(T, result, ydual) return result diff --git a/src/partials.jl b/src/partials.jl index fce67b0a..a26ec165 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -7,7 +7,7 @@ end ############################## @generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i} - ex = Expr(:tuple, [ifelse(i === j, :(one(V)), :(zero(V))) for j in 1:N]...) + ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...) return :(Partials($(ex))) end @@ -92,18 +92,18 @@ end if NANSAFE_MODE_ENABLED @inline function Base.:*(partials::Partials, x::Real) - x = ifelse(!isfinite(x) && iszero(partials), one(x), x) + x = ifelse(!isfinite(x) && iszero(partials), oneunit(x), x) return Partials(scale_tuple(partials.values, x)) end @inline function Base.:/(partials::Partials, x::Real) - x = ifelse(x == zero(x) && iszero(partials), one(x), x) + x = ifelse(x == zero(x) && iszero(partials), oneunit(x), x) return Partials(div_tuple_by_scalar(partials.values, x)) end @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N - x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a) - x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b) + x_a = ifelse(!isfinite(x_a) && iszero(a), oneunit(x_a), x_a) + x_b = ifelse(!isfinite(x_b) && iszero(b), oneunit(x_b), x_b) return Partials(mul_tuples(a.values, b.values, x_a, x_b)) end else @@ -184,7 +184,7 @@ end @generated function one_tuple(::Type{NTuple{N,V}}) where {N,V} ex = tupexpr(i -> :(z), N) return quote - z = one(V) + z = oneunit(V) return $ex end end