From 9e429ff9249deacf0a2d58386647d19bd98d6c8f Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Tue, 12 Feb 2019 17:12:31 -0500 Subject: [PATCH 1/3] expand NaN-safe mode to cover individual zeros in Partials --- src/dual.jl | 2 +- src/partials.jl | 49 ++++++++++++++++++++++++++++++++++++-------- test/PartialsTest.jl | 4 ++-- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index a76500b7..ab80224e 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -292,7 +292,7 @@ end # Predicates # #------------# -isconstant(d::Dual) = iszero(partials(d)) +isconstant(d::Dual) = all_zero_partials(partials(d)) for pred in UNARY_PREDICATES @eval Base.$(pred)(d::Dual) = $(pred)(value(d)) diff --git a/src/partials.jl b/src/partials.jl index eb33487c..95cb75f6 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -34,7 +34,7 @@ Base.mightalias(x::AbstractArray, y::Partials) = false # Generic Functions # ##################### -@inline iszero(partials::Partials) = iszero_tuple(partials.values) +@inline all_zero_partials(partials::Partials) = iszero_tuple(partials.values) @inline Base.zero(partials::Partials) = zero(typeof(partials)) @inline Base.zero(::Type{Partials{N,V}}) where {N,V} = Partials{N,V}(zero_tuple(NTuple{N,V})) @@ -92,19 +92,15 @@ end if NANSAFE_MODE_ENABLED @inline function Base.:*(partials::Partials, x::Real) - x = ifelse(!isfinite(x) && iszero(partials), one(x), x) - return Partials(scale_tuple(partials.values, x)) + return Partials(scale_tuple_nansafe(partials.values, x)) end @inline function Base.:/(partials::Partials, x::Real) - x = ifelse(x == zero(x) && iszero(partials), one(x), x) - return Partials(div_tuple_by_scalar(partials.values, x)) + return Partials(div_tuple_by_scalar_nansafe(partials.values, x)) end @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N - x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a) - x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b) - return Partials(mul_tuples(a.values, b.values, x_a, x_b)) + return Partials(mul_tuples_nansafe(a.values, b.values, x_a, x_b)) end else @inline function Base.:*(partials::Partials, x::Real) @@ -201,10 +197,32 @@ end return tupexpr(i -> :(tup[$i] * x), N) end +@generated function scale_tuple_nansafe(tup::NTuple{N}, x) where N + ex = tupexpr(i -> quote + t_i = tup[$i] + ifelse(is_x_inf && iszero(t_i), t_i, t_i * x) + end, N) + return quote + is_x_inf = !isfinite(x) + return $ex + end +end + @generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N return tupexpr(i -> :(tup[$i] / x), N) end +@generated function div_tuple_by_scalar_nansafe(tup::NTuple{N}, x) where N + ex = tupexpr(i -> quote + t_i = tup[$i] + ifelse(is_x_zero && iszero(t_i), t_i, t_i / x) + end, N) + return quote + is_x_zero = iszero(x) + return $ex + end +end + @generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N return tupexpr(i -> :(a[$i] + b[$i]), N) end @@ -221,6 +239,21 @@ end return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N) end +@generated function mul_tuples_nansafe(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N + ex = tupexpr(i -> quote + a_i = a[$i] + b_i = b[$i] + a_term = ifelse(is_af_inf && iszero(a_i), a_i, a_i * afactor) + b_term = ifelse(is_bf_inf && iszero(b_i), b_i, b_i * bfactor) + a_term + b_term + end, N) + return quote + is_af_inf = !isfinite(afactor) + is_bf_inf = !isfinite(bfactor) + return $ex + end +end + ################### # Pretty Printing # ################### diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index c6ff0e8b..76ac7a08 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -54,8 +54,8 @@ for N in (0, 3), T in (Int, Float32, Float64) @test rand(samerng(), PARTIALS) == rand(samerng(), typeof(PARTIALS)) - @test ForwardDiff.iszero(PARTIALS) == (N == 0) - @test ForwardDiff.iszero(zero(PARTIALS)) + @test ForwardDiff.all_zero_partials(PARTIALS) == (N == 0) + @test ForwardDiff.all_zero_partials(zero(PARTIALS)) @test PARTIALS == copy(PARTIALS) @test (PARTIALS == PARTIALS2) == (N == 0) From 8405459f06c0cb0772a55c64480712551fceb7c8 Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Tue, 19 Feb 2019 14:34:18 -0500 Subject: [PATCH 2/3] add tests --- src/partials.jl | 2 ++ test/PartialsTest.jl | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/partials.jl b/src/partials.jl index 95cb75f6..63cfdfb9 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -2,6 +2,8 @@ struct Partials{N,V} <: AbstractVector{V} values::NTuple{N,V} end +Partials(values::Tuple) = Partials(promote(values...)) + ############################## # Utility/Accessor Functions # ############################## diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index 76ac7a08..03051d17 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -120,10 +120,14 @@ for N in (0, 3), T in (Int, Float32, Float64) if ForwardDiff.NANSAFE_MODE_ENABLED ZEROS = Partials((fill(zero(T), N)...,)) + SEED = ForwardDiff.single_seed(Partials{N,T}, Val(N)) @test (NaN * ZEROS).values == ZEROS.values + @test (NaN * SEED).values === promote(fill(zero(T), N - 1)..., NaN) @test (Inf * ZEROS).values == ZEROS.values - @test (ZEROS / 0).values == ZEROS.values + @test (Inf * SEED).values === promote(fill(zero(T), N - 1)..., Inf) + @test (ZEROS / 0.0).values == ZEROS.values + @test (SEED / 0.0).values === promote(fill(zero(T), N - 1)..., Inf) @test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values @test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values From e7dc4d7e82e6559cc96368a32c2bb14c99db82d7 Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Tue, 19 Feb 2019 14:48:35 -0500 Subject: [PATCH 3/3] make docs less precise but more correct --- docs/src/user/advanced.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/user/advanced.md b/docs/src/user/advanced.md index 8ff07758..f850a1eb 100644 --- a/docs/src/user/advanced.md +++ b/docs/src/user/advanced.md @@ -134,7 +134,7 @@ or `isinf(y)`, in which case a `NaN` derivative will be propagated instead. It is possible to fix this behavior by checking that the perturbation component is zero before attempting to propagate derivative information, but this check can noticeably -decrease performance (~5%-10% on our benchmarks). +decrease performance. In order to preserve performance in the majority of use cases, ForwardDiff disables this check by default. If your code is affected by this `NaN` behvaior, you can enable