diff --git a/docs/src/user/advanced.md b/docs/src/user/advanced.md index 8ff07758..f850a1eb 100644 --- a/docs/src/user/advanced.md +++ b/docs/src/user/advanced.md @@ -134,7 +134,7 @@ or `isinf(y)`, in which case a `NaN` derivative will be propagated instead. It is possible to fix this behavior by checking that the perturbation component is zero before attempting to propagate derivative information, but this check can noticeably -decrease performance (~5%-10% on our benchmarks). +decrease performance. In order to preserve performance in the majority of use cases, ForwardDiff disables this check by default. If your code is affected by this `NaN` behvaior, you can enable diff --git a/src/dual.jl b/src/dual.jl index a76500b7..ab80224e 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -292,7 +292,7 @@ end # Predicates # #------------# -isconstant(d::Dual) = iszero(partials(d)) +isconstant(d::Dual) = all_zero_partials(partials(d)) for pred in UNARY_PREDICATES @eval Base.$(pred)(d::Dual) = $(pred)(value(d)) diff --git a/src/partials.jl b/src/partials.jl index eb33487c..63cfdfb9 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -2,6 +2,8 @@ struct Partials{N,V} <: AbstractVector{V} values::NTuple{N,V} end +Partials(values::Tuple) = Partials(promote(values...)) + ############################## # Utility/Accessor Functions # ############################## @@ -34,7 +36,7 @@ Base.mightalias(x::AbstractArray, y::Partials) = false # Generic Functions # ##################### -@inline iszero(partials::Partials) = iszero_tuple(partials.values) +@inline all_zero_partials(partials::Partials) = iszero_tuple(partials.values) @inline Base.zero(partials::Partials) = zero(typeof(partials)) @inline Base.zero(::Type{Partials{N,V}}) where {N,V} = Partials{N,V}(zero_tuple(NTuple{N,V})) @@ -92,19 +94,15 @@ end if NANSAFE_MODE_ENABLED @inline function Base.:*(partials::Partials, x::Real) - x = ifelse(!isfinite(x) && iszero(partials), one(x), x) - return Partials(scale_tuple(partials.values, x)) + return Partials(scale_tuple_nansafe(partials.values, x)) end @inline function Base.:/(partials::Partials, x::Real) - x = ifelse(x == zero(x) && iszero(partials), one(x), x) - return Partials(div_tuple_by_scalar(partials.values, x)) + return Partials(div_tuple_by_scalar_nansafe(partials.values, x)) end @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N - x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a) - x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b) - return Partials(mul_tuples(a.values, b.values, x_a, x_b)) + return Partials(mul_tuples_nansafe(a.values, b.values, x_a, x_b)) end else @inline function Base.:*(partials::Partials, x::Real) @@ -201,10 +199,32 @@ end return tupexpr(i -> :(tup[$i] * x), N) end +@generated function scale_tuple_nansafe(tup::NTuple{N}, x) where N + ex = tupexpr(i -> quote + t_i = tup[$i] + ifelse(is_x_inf && iszero(t_i), t_i, t_i * x) + end, N) + return quote + is_x_inf = !isfinite(x) + return $ex + end +end + @generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N return tupexpr(i -> :(tup[$i] / x), N) end +@generated function div_tuple_by_scalar_nansafe(tup::NTuple{N}, x) where N + ex = tupexpr(i -> quote + t_i = tup[$i] + ifelse(is_x_zero && iszero(t_i), t_i, t_i / x) + end, N) + return quote + is_x_zero = iszero(x) + return $ex + end +end + @generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N return tupexpr(i -> :(a[$i] + b[$i]), N) end @@ -221,6 +241,21 @@ end return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N) end +@generated function mul_tuples_nansafe(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N + ex = tupexpr(i -> quote + a_i = a[$i] + b_i = b[$i] + a_term = ifelse(is_af_inf && iszero(a_i), a_i, a_i * afactor) + b_term = ifelse(is_bf_inf && iszero(b_i), b_i, b_i * bfactor) + a_term + b_term + end, N) + return quote + is_af_inf = !isfinite(afactor) + is_bf_inf = !isfinite(bfactor) + return $ex + end +end + ################### # Pretty Printing # ################### diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index c6ff0e8b..03051d17 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -54,8 +54,8 @@ for N in (0, 3), T in (Int, Float32, Float64) @test rand(samerng(), PARTIALS) == rand(samerng(), typeof(PARTIALS)) - @test ForwardDiff.iszero(PARTIALS) == (N == 0) - @test ForwardDiff.iszero(zero(PARTIALS)) + @test ForwardDiff.all_zero_partials(PARTIALS) == (N == 0) + @test ForwardDiff.all_zero_partials(zero(PARTIALS)) @test PARTIALS == copy(PARTIALS) @test (PARTIALS == PARTIALS2) == (N == 0) @@ -120,10 +120,14 @@ for N in (0, 3), T in (Int, Float32, Float64) if ForwardDiff.NANSAFE_MODE_ENABLED ZEROS = Partials((fill(zero(T), N)...,)) + SEED = ForwardDiff.single_seed(Partials{N,T}, Val(N)) @test (NaN * ZEROS).values == ZEROS.values + @test (NaN * SEED).values === promote(fill(zero(T), N - 1)..., NaN) @test (Inf * ZEROS).values == ZEROS.values - @test (ZEROS / 0).values == ZEROS.values + @test (Inf * SEED).values === promote(fill(zero(T), N - 1)..., Inf) + @test (ZEROS / 0.0).values == ZEROS.values + @test (SEED / 0.0).values === promote(fill(zero(T), N - 1)..., Inf) @test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values @test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values