Skip to content

Commit a93cc69

Browse files
committed
get FMA working
1 parent a53de0f commit a93cc69

File tree

2 files changed

+39
-33
lines changed

2 files changed

+39
-33
lines changed

src/dual.jl

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b
148148
end
149149
append!(defs.args, expr.args)
150150
end
151+
expr = quote
152+
@inline $(f)(x::Dual{T}, y::$R, z::$R) where {T} = $x_body
153+
@inline $(f)(x::$R, y::Dual{T}, z::$R) where {T} = $y_body
154+
@inline $(f)(x::$R, y::$R, z::Dual{T}) where {T} = $z_body
155+
end
156+
append!(defs.args, expr.args)
151157
end
152158
return esc(defs)
153159
end
@@ -391,23 +397,21 @@ end
391397
# Special Cases #
392398
#################
393399

394-
# Manually Optimized Functions #
395-
#------------------------------#
400+
# exp
396401

397402
@inline function Base.exp{T}(d::Dual{T})
398403
expv = exp(value(d))
399404
return Dual{T}(expv, expv * partials(d))
400405
end
401406

407+
# sqrt
408+
402409
@inline function Base.sqrt{T}(d::Dual{T})
403410
sqrtv = sqrt(value(d))
404411
deriv = inv(sqrtv + sqrtv)
405412
return Dual{T}(sqrtv, deriv * partials(d))
406413
end
407414

408-
# Other Functions #
409-
#-----------------#
410-
411415
# hypot
412416

413417
@inline function calc_hypot{T}(x, y, ::Type{T})
@@ -461,22 +465,28 @@ end
461465
calc_atan2(x, y, T)
462466
)
463467

464-
@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N})
468+
# fma
469+
470+
@generated function calc_fma_xyz(x::Dual{T,<:Real,N},
471+
y::Dual{T,<:Real,N},
472+
z::Dual{T,<:Real,N}) where {T,N}
465473
ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
466474
return quote
467475
$(Expr(:meta, :inline))
468476
v = fma(value(x), value(y), value(z))
469-
Dual(v, $ex)
477+
return Dual{T}(v, $ex)
470478
end
471479
end
472480

473-
@inline function Base.fma(x::Dual, y::Dual, z::Real)
481+
@inline function calc_fma_xy(x::Dual{T}, y::Dual{T}, z::Real) where T
474482
vx, vy = value(x), value(y)
475483
result = fma(vx, vy, z)
476-
return Dual(result, _mul_partials(partials(x), partials(y), vy, vx))
484+
return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
477485
end
478486

479-
@generated function Base.fma{N}(x::Dual{N}, y::Real, z::Dual{N})
487+
@generated function calc_fma_xz(x::Dual{T,<:Real,N},
488+
y::Real,
489+
z::Dual{T,<:Real,N}) where {T,N}
480490
ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
481491
return quote
482492
$(Expr(:meta, :inline))
@@ -485,14 +495,16 @@ end
485495
end
486496
end
487497

488-
@inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z)
489-
490-
@inline function Base.fma(x::Dual, y::Real, z::Real)
491-
vx = value(x)
492-
return Dual(fma(vx, y, value(z)), partials(x) * y)
493-
end
494-
495-
@inline Base.fma(x::Real, y::Dual, z::Real) = fma(y, x, z)
498+
@define_ternary_dual_op(
499+
Base.fma,
500+
calc_fma_xyz(x, y, z), # xyz_body
501+
calc_fma_xy(x, y, z), # xy_body
502+
calc_fma_xz(x, y, z), # xz_body
503+
Base.fma(y, x, z), # yz_body
504+
Dual{T}(fma(value(x), y, z), partials(x) * y), # x_body
505+
Base.fma(y, x, z), # y_body
506+
Dual{T}(fma(x, y, value(z)), partials(z)) # z_body
507+
)
496508

497509
# sincos
498510

test/DualTest.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,21 +367,6 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
367367

368368
@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0
369369

370-
test_approx_diffnums(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
371-
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
372-
PARTIALS3))
373-
test_approx_diffnums(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
374-
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS))
375-
test_approx_diffnums(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
376-
PRIMAL*PARTIALS2 + PARTIALS3))
377-
test_approx_diffnums(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
378-
PRIMAL*PARTIALS2))
379-
test_approx_diffnums(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
380-
PRIMAL2*PARTIALS + PARTIALS3))
381-
test_approx_diffnums(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
382-
PRIMAL2*PARTIALS))
383-
test_approx_diffnums(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3))
384-
385370
# Unary Functions #
386371
#-----------------#
387372

@@ -437,12 +422,21 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
437422

438423
@test dual_isapprox(hypot(FDNUM, FDNUM2), sqrt(FDNUM^2 + FDNUM2^2))
439424
@test dual_isapprox(hypot(FDNUM, FDNUM2, FDNUM), sqrt(2*(FDNUM^2) + FDNUM2^2))
425+
440426
@test all(map(dual_isapprox, ForwardDiff.sincos(FDNUM), (sin(FDNUM), cos(FDNUM))))
441427

442428
if V === Float32
443429
@test typeof(sqrt(FDNUM)) === typeof(FDNUM)
444430
@test typeof(sqrt(NESTED_FDNUM)) === typeof(NESTED_FDNUM)
445431
end
432+
433+
@test dual_isapprox(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS + PARTIALS3))
434+
@test dual_isapprox(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS))
435+
@test dual_isapprox(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2 + PARTIALS3))
436+
@test dual_isapprox(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2))
437+
@test dual_isapprox(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL2*PARTIALS + PARTIALS3))
438+
@test dual_isapprox(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL2*PARTIALS))
439+
@test dual_isapprox(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3))
446440
end
447441

448442
end # module

0 commit comments

Comments
 (0)