From 7121d4bf0817ee5e36b0eb8fe15f2fcd9e14ea5f Mon Sep 17 00:00:00 2001 From: "Tamas K. Papp" Date: Wed, 8 Mar 2017 13:35:55 +0100 Subject: [PATCH] Added fma, with tests. Tests required a third value set (PRIMAL etc) in the testing framework, which was also added. Fixes issue #202. --- src/dual.jl | 32 ++++++++++++++++++++++++++++++++ test/DualTest.jl | 19 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/dual.jl b/src/dual.jl index 6e0d105d..dfed4927 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -410,6 +410,38 @@ end @ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x) @ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x) +@inline function Base.fma(x::Dual, y::Dual, z::Dual) + vx, vy = value(x), value(y) + result = fma(vx, vy, value(z)) + return Dual(result, + _mul_partials(partials(x), partials(y), vx, vy) + partials(z)) +end + +@inline function Base.fma(x::Dual, y::Dual, z::Real) + vx, vy = value(x), value(y) + result = fma(vx, vy, z) + return Dual(result, _mul_partials(partials(x), partials(y), vx, vy)) +end + +@inline function Base.fma(x::Dual, y::Real, z::Dual) + vx = value(x) + result = fma(vx, y, value(z)) + return Dual(result, partials(x) * y + partials(z)) +end + +@inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z) + +@inline function Base.fma(x::Dual, y::Real, z::Real) + vx = value(x) + return Dual(fma(vx, y, value(z)), partials(x) * y) +end + +@inline Base.fma(x::Real, y::Dual, z::Real) = fma(y, x, z) + +@inline function Base.fma(x::Real, y::Real, z::Dual) + Dual(fma(x, y, value(z)), partials(z)) +end + ################### # Pretty Printing # ################### diff --git a/test/DualTest.jl b/test/DualTest.jl index ca9931a3..ef96341b 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -48,6 +48,10 @@ for N in (0,3), M in (0,4), T in (Int, Float32) PRIMAL2 = intrand(T) FDNUM2 = Dual(PRIMAL2, PARTIALS2) + PARTIALS3 = Partials{N,T}(ntuple(n -> intrand(T), Val{N})) + PRIMAL3 = intrand(T) + FDNUM3 = Dual(PRIMAL3, PARTIALS3) + M_PARTIALS = Partials{M,T}(ntuple(m -> intrand(T), Val{M})) NESTED_PARTIALS = convert(Partials{N,Dual{M,T}}, PARTIALS) NESTED_FDNUM = Dual(Dual(PRIMAL, M_PARTIALS), NESTED_PARTIALS) @@ -383,6 +387,21 @@ for N in (0,3), M in (0,4), T in (Int, Float32) @test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0 + @test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS + + PARTIALS3) + @test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS) + @test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PARTIALS3) + @test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2) + @test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS + PARTIALS3) + @test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS) + @test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3) + # Unary Functions # #-----------------#