Skip to content

Commit b6d3636

Browse files
committed
fix wrong partials multiplied in FMA
1 parent 6c61b61 commit b6d3636

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

src/dual.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,13 +414,13 @@ end
414414
vx, vy = value(x), value(y)
415415
result = fma(vx, vy, value(z))
416416
return Dual(result,
417-
_mul_partials(partials(x), partials(y), vx, vy) + partials(z))
417+
_mul_partials(partials(x), partials(y), vy, vx) + partials(z))
418418
end
419419

420420
@inline function Base.fma(x::Dual, y::Dual, z::Real)
421421
vx, vy = value(x), value(y)
422422
result = fma(vx, vy, z)
423-
return Dual(result, _mul_partials(partials(x), partials(y), vx, vy))
423+
return Dual(result, _mul_partials(partials(x), partials(y), vy, vx))
424424
end
425425

426426
@inline function Base.fma(x::Dual, y::Real, z::Dual)

test/DualTest.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ samerng() = MersenneTwister(1)
1515
# exponent by one
1616
intrand(T) = T == Int ? rand(2:10) : rand(T)
1717

18+
dualapprox(A, B) = value(A) value(B) && partials(A) partials(B)
19+
1820
# fix testing issue with Base.hypot(::Int...) undefined in 0.4
1921
if v"0.4" <= VERSION < v"0.5"
2022
Base.hypot(x::Int, y::Int) = Base.hypot(Float64(x), Float64(y))
@@ -387,20 +389,20 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
387389

388390
@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0
389391

390-
@test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
392+
@test dualapprox(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
391393
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
392-
PARTIALS3)
393-
@test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
394-
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)
395-
@test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396-
PRIMAL*PARTIALS2 + PARTIALS3)
397-
@test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398-
PRIMAL*PARTIALS2)
399-
@test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400-
PRIMAL2*PARTIALS + PARTIALS3)
401-
@test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402-
PRIMAL2*PARTIALS)
403-
@test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)
394+
PARTIALS3))
395+
@test dualapprox(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396+
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS))
397+
@test dualapprox(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398+
PRIMAL*PARTIALS2 + PARTIALS3))
399+
@test dualapprox(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400+
PRIMAL*PARTIALS2))
401+
@test dualapprox(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402+
PRIMAL2*PARTIALS + PARTIALS3))
403+
@test dualapprox(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
404+
PRIMAL2*PARTIALS))
405+
@test dualapprox(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3))
404406

405407
# Unary Functions #
406408
#-----------------#

0 commit comments

Comments
 (0)