@@ -148,6 +148,12 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b
148148 end
149149 append! (defs. args, expr. args)
150150 end
151+ expr = quote
152+ @inline $ (f)(x:: Dual{T} , y:: $R , z:: $R ) where {T} = $ x_body
153+ @inline $ (f)(x:: $R , y:: Dual{T} , z:: $R ) where {T} = $ y_body
154+ @inline $ (f)(x:: $R , y:: $R , z:: Dual{T} ) where {T} = $ z_body
155+ end
156+ append! (defs. args, expr. args)
151157 end
152158 return esc (defs)
153159end
@@ -391,23 +397,21 @@ end
391397# Special Cases #
392398# ################
393399
394- # Manually Optimized Functions #
395- # ------------------------------#
400+ # exp
396401
397402@inline function Base. exp {T} (d:: Dual{T} )
398403 expv = exp (value (d))
399404 return Dual {T} (expv, expv * partials (d))
400405end
401406
407+ # sqrt
408+
402409@inline function Base. sqrt {T} (d:: Dual{T} )
403410 sqrtv = sqrt (value (d))
404411 deriv = inv (sqrtv + sqrtv)
405412 return Dual {T} (sqrtv, deriv * partials (d))
406413end
407414
408- # Other Functions #
409- # -----------------#
410-
411415# hypot
412416
413417@inline function calc_hypot {T} (x, y, :: Type{T} )
@@ -461,22 +465,28 @@ end
461465 calc_atan2 (x, y, T)
462466)
463467
464- @generated function Base. fma {N} (x:: Dual{N} , y:: Dual{N} , z:: Dual{N} )
468+ # fma
469+
470+ @generated function calc_fma_xyz (x:: Dual{T,<:Real,N} ,
471+ y:: Dual{T,<:Real,N} ,
472+ z:: Dual{T,<:Real,N} ) where {T,N}
465473 ex = Expr (:tuple , [:(fma (value (x), partials (y)[$ i], fma (value (y), partials (x)[$ i], partials (z)[$ i]))) for i in 1 : N]. .. )
466474 return quote
467475 $ (Expr (:meta , :inline ))
468476 v = fma (value (x), value (y), value (z))
469- Dual (v, $ ex)
477+ return Dual {T} (v, $ ex)
470478 end
471479end
472480
473- @inline function Base . fma (x:: Dual , y:: Dual , z:: Real )
481+ @inline function calc_fma_xy (x:: Dual{T} , y:: Dual{T} , z:: Real ) where T
474482 vx, vy = value (x), value (y)
475483 result = fma (vx, vy, z)
476- return Dual (result, _mul_partials (partials (x), partials (y), vy, vx))
484+ return Dual {T} (result, _mul_partials (partials (x), partials (y), vy, vx))
477485end
478486
479- @generated function Base. fma {N} (x:: Dual{N} , y:: Real , z:: Dual{N} )
487+ @generated function calc_fma_xz (x:: Dual{T,<:Real,N} ,
488+ y:: Real ,
489+ z:: Dual{T,<:Real,N} ) where {T,N}
480490 ex = Expr (:tuple , [:(fma (partials (x)[$ i], y, partials (z)[$ i])) for i in 1 : N]. .. )
481491 return quote
482492 $ (Expr (:meta , :inline ))
@@ -485,14 +495,16 @@ end
485495 end
486496end
487497
488- @inline Base. fma (x:: Real , y:: Dual , z:: Dual ) = fma (y, x, z)
489-
490- @inline function Base. fma (x:: Dual , y:: Real , z:: Real )
491- vx = value (x)
492- return Dual (fma (vx, y, value (z)), partials (x) * y)
493- end
494-
495- @inline Base. fma (x:: Real , y:: Dual , z:: Real ) = fma (y, x, z)
498+ @define_ternary_dual_op (
499+ Base. fma,
500+ calc_fma_xyz (x, y, z), # xyz_body
501+ calc_fma_xy (x, y, z), # xy_body
502+ calc_fma_xz (x, y, z), # xz_body
503+ Base. fma (y, x, z), # yz_body
504+ Dual {T} (fma (value (x), y, z), partials (x) * y), # x_body
505+ Base. fma (y, x, z), # y_body
506+ Dual {T} (fma (x, y, value (z)), partials (z)) # z_body
507+ )
496508
497509# sincos
498510
0 commit comments