Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/user/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ or `isinf(y)`, in which case a `NaN` derivative will be propagated instead.

It is possible to fix this behavior by checking that the perturbation component is zero
before attempting to propagate derivative information, but this check can noticeably
decrease performance (~5%-10% on our benchmarks).
decrease performance.

In order to preserve performance in the majority of use cases, ForwardDiff disables this
check by default. If your code is affected by this `NaN` behvaior, you can enable
Expand Down
2 changes: 1 addition & 1 deletion src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ end
# Predicates #
#------------#

isconstant(d::Dual) = iszero(partials(d))
isconstant(d::Dual) = all_zero_partials(partials(d))

for pred in UNARY_PREDICATES
@eval Base.$(pred)(d::Dual) = $(pred)(value(d))
Expand Down
51 changes: 43 additions & 8 deletions src/partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ struct Partials{N,V} <: AbstractVector{V}
values::NTuple{N,V}
end

Partials(values::Tuple) = Partials(promote(values...))

##############################
# Utility/Accessor Functions #
##############################
Expand Down Expand Up @@ -34,7 +36,7 @@ Base.mightalias(x::AbstractArray, y::Partials) = false
# Generic Functions #
#####################

@inline iszero(partials::Partials) = iszero_tuple(partials.values)
@inline all_zero_partials(partials::Partials) = iszero_tuple(partials.values)

@inline Base.zero(partials::Partials) = zero(typeof(partials))
@inline Base.zero(::Type{Partials{N,V}}) where {N,V} = Partials{N,V}(zero_tuple(NTuple{N,V}))
Expand Down Expand Up @@ -92,19 +94,15 @@ end

if NANSAFE_MODE_ENABLED
@inline function Base.:*(partials::Partials, x::Real)
x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
return Partials(scale_tuple(partials.values, x))
return Partials(scale_tuple_nansafe(partials.values, x))
end

@inline function Base.:/(partials::Partials, x::Real)
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
return Partials(div_tuple_by_scalar(partials.values, x))
return Partials(div_tuple_by_scalar_nansafe(partials.values, x))
end

@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
return Partials(mul_tuples_nansafe(a.values, b.values, x_a, x_b))
end
else
@inline function Base.:*(partials::Partials, x::Real)
Expand Down Expand Up @@ -201,10 +199,32 @@ end
return tupexpr(i -> :(tup[$i] * x), N)
end

@generated function scale_tuple_nansafe(tup::NTuple{N}, x) where N
ex = tupexpr(i -> quote
t_i = tup[$i]
ifelse(is_x_inf && iszero(t_i), t_i, t_i * x)
end, N)
return quote
is_x_inf = !isfinite(x)
return $ex
end
end

@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
return tupexpr(i -> :(tup[$i] / x), N)
end

@generated function div_tuple_by_scalar_nansafe(tup::NTuple{N}, x) where N
ex = tupexpr(i -> quote
t_i = tup[$i]
ifelse(is_x_zero && iszero(t_i), t_i, t_i / x)
end, N)
return quote
is_x_zero = iszero(x)
return $ex
end
end

@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
return tupexpr(i -> :(a[$i] + b[$i]), N)
end
Expand All @@ -221,6 +241,21 @@ end
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
end

@generated function mul_tuples_nansafe(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
ex = tupexpr(i -> quote
a_i = a[$i]
b_i = b[$i]
a_term = ifelse(is_af_inf && iszero(a_i), a_i, a_i * afactor)
b_term = ifelse(is_bf_inf && iszero(b_i), b_i, b_i * bfactor)
a_term + b_term
end, N)
return quote
is_af_inf = !isfinite(afactor)
is_bf_inf = !isfinite(bfactor)
return $ex
end
end

###################
# Pretty Printing #
###################
Expand Down
10 changes: 7 additions & 3 deletions test/PartialsTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ for N in (0, 3), T in (Int, Float32, Float64)

@test rand(samerng(), PARTIALS) == rand(samerng(), typeof(PARTIALS))

@test ForwardDiff.iszero(PARTIALS) == (N == 0)
@test ForwardDiff.iszero(zero(PARTIALS))
@test ForwardDiff.all_zero_partials(PARTIALS) == (N == 0)
@test ForwardDiff.all_zero_partials(zero(PARTIALS))

@test PARTIALS == copy(PARTIALS)
@test (PARTIALS == PARTIALS2) == (N == 0)
Expand Down Expand Up @@ -120,10 +120,14 @@ for N in (0, 3), T in (Int, Float32, Float64)

if ForwardDiff.NANSAFE_MODE_ENABLED
ZEROS = Partials((fill(zero(T), N)...,))
SEED = ForwardDiff.single_seed(Partials{N,T}, Val(N))

@test (NaN * ZEROS).values == ZEROS.values
@test (NaN * SEED).values === promote(fill(zero(T), N - 1)..., NaN)
@test (Inf * ZEROS).values == ZEROS.values
@test (ZEROS / 0).values == ZEROS.values
@test (Inf * SEED).values === promote(fill(zero(T), N - 1)..., Inf)
@test (ZEROS / 0.0).values == ZEROS.values
@test (SEED / 0.0).values === promote(fill(zero(T), N - 1)..., Inf)

@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values
Expand Down