Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,38 @@ end
@ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x)
@ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x)

@inline function Base.fma(x::Dual, y::Dual, z::Dual)
vx, vy = value(x), value(y)
result = fma(vx, vy, value(z))
return Dual(result,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to define a fma function for Partials as well, so that fma instructions will be used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean for the

Base.fma(x::Dual, y::Real, z::Dual)

case? For the others I don't see how it would help.

Copy link
Collaborator

@KristofferC KristofferC Mar 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something like (where R is the value part and E is the partials part):

Partials(x * y + z)
=
Partials[(Rx * [Ey1, Ey2, ...]) +
         (Ry * [Ex1, Ex2, ...])
             + [Ez1, Ez2, ...]
        ]
= 
[Rx * Ey1 + Ry * Ex1 + Ez1,
 Rx * Ey2 + Ry * Ex2 + Ez2,
 ...]
= 
[fma(Rx, Ey1, fma(Ry, Ex1, Ez1)),
 fma(Rx, Ey2, fma(Ry, Ex2, Ez2)),
 ...]

An implementation would be:

@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N})
    ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
    return quote
        v = fma(value(x), value(y), value(z))
        Dual(v, $ex)
    end
end

so that:

julia> x, y, z = Dual(rand(9)...) , Dual(rand(9)...), Dual(rand(9)...);

julia> x*y + z - fma(x,y,z)
Dual(-5.551115123125783e-17,0.0,0.0,0.0,0.0,0.0,0.0,-2.220446049250313e-16,0.0)

This gives a bunch of fused multiply adds in the generated code:

	vfmadd213sd	8(%rcx), %xmm2, %xmm8
	vfmadd231sd	8(%rdx), %xmm1, %xmm8

but I'm not sure how big of a performance difference it is in general.

Copy link
Collaborator

@KristofferC KristofferC Mar 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some benchmarking with and without O3 shows that the fma version is just barely (5-10%) faster than a function f(x,y,z) = x*y + z (with my processor at least). Might not be worth the extra complexity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that the point of fma is not speed but precision. Try

for i in 1:100
    x, y, z = rand(3)
    if fma(x,y,z) != x*y+z
        println("*")
    end
end

Copy link
Collaborator

@KristofferC KristofferC Mar 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was both. Anyway, regarding accuracy, I guess the function I posted has a purpose then, since it does more fusing...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly. Could you perhaps merge my PR as is, then you could add your function?

Copy link
Collaborator

@KristofferC KristofferC Mar 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I let @jrevels do the merging since this is his baby :)

_mul_partials(partials(x), partials(y), vx, vy) + partials(z))
end

@inline function Base.fma(x::Dual, y::Dual, z::Real)
vx, vy = value(x), value(y)
result = fma(vx, vy, z)
return Dual(result, _mul_partials(partials(x), partials(y), vx, vy))
end

@inline function Base.fma(x::Dual, y::Real, z::Dual)
vx = value(x)
result = fma(vx, y, value(z))
return Dual(result, partials(x) * y + partials(z))
end

@inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z)

@inline function Base.fma(x::Dual, y::Real, z::Real)
vx = value(x)
return Dual(fma(vx, y, value(z)), partials(x) * y)
end

@inline Base.fma(x::Real, y::Dual, z::Real) = fma(y, x, z)

@inline function Base.fma(x::Real, y::Real, z::Dual)
Dual(fma(x, y, value(z)), partials(z))
end

###################
# Pretty Printing #
###################
Expand Down
19 changes: 19 additions & 0 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
PRIMAL2 = intrand(T)
FDNUM2 = Dual(PRIMAL2, PARTIALS2)

PARTIALS3 = Partials{N,T}(ntuple(n -> intrand(T), Val{N}))
PRIMAL3 = intrand(T)
FDNUM3 = Dual(PRIMAL3, PARTIALS3)

M_PARTIALS = Partials{M,T}(ntuple(m -> intrand(T), Val{M}))
NESTED_PARTIALS = convert(Partials{N,Dual{M,T}}, PARTIALS)
NESTED_FDNUM = Dual(Dual(PRIMAL, M_PARTIALS), NESTED_PARTIALS)
Expand Down Expand Up @@ -383,6 +387,21 @@ for N in (0,3), M in (0,4), T in (Int, Float32)

@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0

@test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
PARTIALS3)
@test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)
@test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL*PARTIALS2 + PARTIALS3)
@test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL*PARTIALS2)
@test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL2*PARTIALS + PARTIALS3)
@test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
PRIMAL2*PARTIALS)
@test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)

# Unary Functions #
#-----------------#

Expand Down