@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
8282@inline Base.:- (partials:: Partials ) = Partials (minus_tuple (partials. values))
8383@inline Base.:* (x:: Real , partials:: Partials ) = partials* x
8484
85+ @inline function Base.:* (partials:: Partials , x:: Real )
86+ return Partials (scale_tuple (partials. values, x))
87+ end
88+
89+ @inline function Base.:/ (partials:: Partials , x:: Real )
90+ return Partials (div_tuple_by_scalar (partials. values, x))
91+ end
92+
93+ @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
94+ return Partials (mul_tuples (a. values, b. values, x_a, x_b))
95+ end
96+
8597@inline function _div_partials (a:: Partials , b:: Partials , aval, bval)
8698 return _mul_partials (a, b, inv (bval), - (aval / (bval* bval)))
8799end
90102# ----------------------#
91103
92104if NANSAFE_MODE_ENABLED
93- @inline function Base.: * (partials :: Partials , x :: Real )
94- x = ifelse ( ! isfinite (x) && iszero (partials), one (x), x)
95- return Partials ( scale_tuple (partials . values, x))
96- end
97-
98- @inline function Base.: / (partials :: Partials , x:: Real )
99- x = ifelse (x == zero (x) && iszero (partials), one (x), x)
100- return Partials ( div_tuple_by_scalar (partials . values, x))
105+ # A dual number with a zero partial is just an unperturbed non-dual number
106+ # Hence when propagated the resulting dual number is unperturbed as well,
107+ # ie., its partial is zero as well, regardless of the primal value
108+ # However, standard floating point multiplication/division would return `NaN`
109+ # if the primal is not-finite/zero
110+ @inline function _mul_partial (partial :: Real , x:: Real )
111+ y = partial * x
112+ return iszero (partial) ? zero (y) : y
101113 end
102-
103- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
104- x_a = ifelse (! isfinite (x_a) && iszero (a), one (x_a), x_a)
105- x_b = ifelse (! isfinite (x_b) && iszero (b), one (x_b), x_b)
106- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
114+ @inline function _div_partial (partial:: Real , x:: Real )
115+ y = partial / x
116+ return iszero (partial) ? zero (y) : y
107117 end
108118else
109- @inline function Base.:* (partials:: Partials , x:: Real )
110- return Partials (scale_tuple (partials. values, x))
111- end
112-
113- @inline function Base.:/ (partials:: Partials , x:: Real )
114- return Partials (div_tuple_by_scalar (partials. values, x))
115- end
116-
117- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
118- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
119- end
119+ @inline _mul_partial (partial:: Real , x:: Real ) = partial * x
120+ @inline _div_partial (partial:: Real , x:: Real ) = partial / x
120121end
121122
122123# edge cases where N == 0 #
@@ -197,11 +198,11 @@ end
197198end
198199
199200@generated function scale_tuple (tup:: NTuple{N} , x) where N
200- return tupexpr (i -> :(tup[$ i] * x ), N)
201+ return tupexpr (i -> :(_mul_partial ( tup[$ i], x) ), N)
201202end
202203
203204@generated function div_tuple_by_scalar (tup:: NTuple{N} , x) where N
204- return tupexpr (i -> :(tup[$ i] / x ), N)
205+ return tupexpr (i -> :(_div_partial ( tup[$ i], x) ), N)
205206end
206207
207208@generated function add_tuples (a:: NTuple{N} , b:: NTuple{N} ) where N
217218end
218219
219220@generated function mul_tuples (a:: NTuple{N} , b:: NTuple{N} , afactor, bfactor) where N
220- return tupexpr (i -> :((afactor * a[$ i]) + (bfactor * b[$ i])), N)
221+ return tupexpr (i -> :(_mul_partial ( a[$ i], afactor ) + _mul_partial ( b[$ i], bfactor )), N)
221222end
222223
223224# ##################
0 commit comments