Skip to content

Commit 91749f4

Browse files
tpappjrevels
authored andcommitted
Added fma, with tests. (#203)
Tests required a third value set (PRIMAL etc) in the testing framework, which was also added. Fixes issue #202.
1 parent 54bdd2c commit 91749f4

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/dual.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,38 @@ end
410410
@ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x)
411411
@ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x)
412412

413+
@inline function Base.fma(x::Dual, y::Dual, z::Dual)
414+
vx, vy = value(x), value(y)
415+
result = fma(vx, vy, value(z))
416+
return Dual(result,
417+
_mul_partials(partials(x), partials(y), vx, vy) + partials(z))
418+
end
419+
420+
@inline function Base.fma(x::Dual, y::Dual, z::Real)
421+
vx, vy = value(x), value(y)
422+
result = fma(vx, vy, z)
423+
return Dual(result, _mul_partials(partials(x), partials(y), vx, vy))
424+
end
425+
426+
@inline function Base.fma(x::Dual, y::Real, z::Dual)
427+
vx = value(x)
428+
result = fma(vx, y, value(z))
429+
return Dual(result, partials(x) * y + partials(z))
430+
end
431+
432+
@inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z)
433+
434+
@inline function Base.fma(x::Dual, y::Real, z::Real)
435+
vx = value(x)
436+
return Dual(fma(vx, y, value(z)), partials(x) * y)
437+
end
438+
439+
@inline Base.fma(x::Real, y::Dual, z::Real) = fma(y, x, z)
440+
441+
@inline function Base.fma(x::Real, y::Real, z::Dual)
442+
Dual(fma(x, y, value(z)), partials(z))
443+
end
444+
413445
###################
414446
# Pretty Printing #
415447
###################

test/DualTest.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
4848
PRIMAL2 = intrand(T)
4949
FDNUM2 = Dual(PRIMAL2, PARTIALS2)
5050

51+
PARTIALS3 = Partials{N,T}(ntuple(n -> intrand(T), Val{N}))
52+
PRIMAL3 = intrand(T)
53+
FDNUM3 = Dual(PRIMAL3, PARTIALS3)
54+
5155
M_PARTIALS = Partials{M,T}(ntuple(m -> intrand(T), Val{M}))
5256
NESTED_PARTIALS = convert(Partials{N,Dual{M,T}}, PARTIALS)
5357
NESTED_FDNUM = Dual(Dual(PRIMAL, M_PARTIALS), NESTED_PARTIALS)
@@ -383,6 +387,21 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
383387

384388
@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0
385389

390+
@test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
391+
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
392+
PARTIALS3)
393+
@test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
394+
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)
395+
@test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396+
PRIMAL*PARTIALS2 + PARTIALS3)
397+
@test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398+
PRIMAL*PARTIALS2)
399+
@test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400+
PRIMAL2*PARTIALS + PARTIALS3)
401+
@test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402+
PRIMAL2*PARTIALS)
403+
@test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)
404+
386405
# Unary Functions #
387406
#-----------------#
388407

0 commit comments

Comments
 (0)