@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
8282@inline Base.:- (partials:: Partials ) = Partials (minus_tuple (partials. values))
8383@inline Base.:* (x:: Real , partials:: Partials ) = partials* x
8484
85+ @inline function Base.:* (partials:: Partials , x:: Real )
86+ return Partials (scale_tuple (partials. values, x))
87+ end
88+
89+ @inline function Base.:/ (partials:: Partials , x:: Real )
90+ return Partials (div_tuple_by_scalar (partials. values, x))
91+ end
92+
93+ @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
94+ return Partials (mul_tuples (a. values, b. values, x_a, x_b))
95+ end
96+
8597@inline function _div_partials (a:: Partials , b:: Partials , aval, bval)
8698 return _mul_partials (a, b, inv (bval), - (aval / (bval* bval)))
8799end
90102# ----------------------#
91103
92104if NANSAFE_MODE_ENABLED
93- @inline function Base.:* (partials:: Partials , x:: Real )
94- x = ifelse (! isfinite (x) && iszero (partials), one (x), x)
95- return Partials (scale_tuple (partials. values, x))
96- end
97-
98- @inline function Base.:/ (partials:: Partials , x:: Real )
99- x = ifelse (x == zero (x) && iszero (partials), one (x), x)
100- return Partials (div_tuple_by_scalar (partials. values, x))
105+ @inline function _mul_partial (partial:: Real , x:: Real )
106+ y = partial * x
107+ return ! isfinite (x) && iszero (partial) ? zero (y) : y
101108 end
102-
103- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
104- x_a = ifelse (! isfinite (x_a) && iszero (a), one (x_a), x_a)
105- x_b = ifelse (! isfinite (x_b) && iszero (b), one (x_b), x_b)
106- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
109+ @inline function _div_partial (partial:: Real , x:: Real )
110+ y = partial / x
111+ return iszero (x) && iszero (partial) ? zero (y) : y
107112 end
108113else
109- @inline function Base.:* (partials:: Partials , x:: Real )
110- return Partials (scale_tuple (partials. values, x))
111- end
112-
113- @inline function Base.:/ (partials:: Partials , x:: Real )
114- return Partials (div_tuple_by_scalar (partials. values, x))
115- end
116-
117- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
118- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
119- end
114+ @inline _mul_partial (partial:: Real , x:: Real ) = partial * x
115+ @inline _div_partial (partial:: Real , x:: Real ) = partial / x
120116end
121117
122118# edge cases where N == 0 #
@@ -197,11 +193,11 @@ end
197193end
198194
199195@generated function scale_tuple (tup:: NTuple{N} , x) where N
200- return tupexpr (i -> :(tup[$ i] * x ), N)
196+ return tupexpr (i -> :(_mul_partial ( tup[$ i], x) ), N)
201197end
202198
203199@generated function div_tuple_by_scalar (tup:: NTuple{N} , x) where N
204- return tupexpr (i -> :(tup[$ i] / x ), N)
200+ return tupexpr (i -> :(_div_partial ( tup[$ i], x) ), N)
205201end
206202
207203@generated function add_tuples (a:: NTuple{N} , b:: NTuple{N} ) where N
217213end
218214
219215@generated function mul_tuples (a:: NTuple{N} , b:: NTuple{N} , afactor, bfactor) where N
220- return tupexpr (i -> :((afactor * a[$ i]) + (bfactor * b[$ i])), N)
216+ return tupexpr (i -> :(_mul_partial ( a[$ i], afactor ) + _mul_partial ( b[$ i], bfactor )), N)
221217end
222218
223219# ##################
0 commit comments