From 280be369dbdb0af4646aefc3f1ee4c773d6be22a Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Sat, 18 Mar 2017 21:00:51 +0100 Subject: [PATCH] fix wrong partials multiplied in FMA --- src/dual.jl | 4 ++-- test/DualTest.jl | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index dfed4927..0cfb7182 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -414,13 +414,13 @@ end vx, vy = value(x), value(y) result = fma(vx, vy, value(z)) return Dual(result, - _mul_partials(partials(x), partials(y), vx, vy) + partials(z)) + _mul_partials(partials(x), partials(y), vy, vx) + partials(z)) end @inline function Base.fma(x::Dual, y::Dual, z::Real) vx, vy = value(x), value(y) result = fma(vx, vy, z) - return Dual(result, _mul_partials(partials(x), partials(y), vx, vy)) + return Dual(result, _mul_partials(partials(x), partials(y), vy, vx)) end @inline function Base.fma(x::Dual, y::Real, z::Dual) diff --git a/test/DualTest.jl b/test/DualTest.jl index ef96341b..53dc610e 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -387,20 +387,20 @@ for N in (0,3), M in (0,4), T in (Int, Float32) @test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0 - @test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + test_approx_diffnums(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS + - PARTIALS3) - @test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS) - @test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2 + PARTIALS3) - @test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2) - @test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL2*PARTIALS + PARTIALS3) - @test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL2*PARTIALS) - @test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3) + PARTIALS3)) + test_approx_diffnums(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)) + test_approx_diffnums(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PARTIALS3)) + test_approx_diffnums(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2)) + test_approx_diffnums(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS + PARTIALS3)) + test_approx_diffnums(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS)) + test_approx_diffnums(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)) # Unary Functions # #-----------------#