Skip to content

Commit fc6ee60

Browse files
committed
fix wrong partials multiplied in FMA
1 parent 6c61b61 commit fc6ee60

File tree

3 files changed

+232
-15
lines changed

3 files changed

+232
-15
lines changed

src/dual.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,13 +414,13 @@ end
414414
vx, vy = value(x), value(y)
415415
result = fma(vx, vy, value(z))
416416
return Dual(result,
417-
_mul_partials(partials(x), partials(y), vx, vy) + partials(z))
417+
_mul_partials(partials(x), partials(y), vy, vx) + partials(z))
418418
end
419419

420420
@inline function Base.fma(x::Dual, y::Dual, z::Real)
421421
vx, vy = value(x), value(y)
422422
result = fma(vx, vy, z)
423-
return Dual(result, _mul_partials(partials(x), partials(y), vx, vy))
423+
return Dual(result, _mul_partials(partials(x), partials(y), vy, vx))
424424
end
425425

426426
@inline function Base.fma(x::Dual, y::Real, z::Dual)

src/partials.jl.mem

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
- immutable Partials{N,T} <: AbstractVector{T}
2+
- values::NTuple{N,T}
3+
- end
4+
-
5+
- ##############################
6+
- # Utility/Accessor Functions #
7+
- ##############################
8+
-
9+
- @inline valtype{N,T}(::Partials{N,T}) = T
10+
- @inline valtype{N,T}(::Type{Partials{N,T}}) = T
11+
-
12+
- @inline npartials{N}(::Partials{N}) = N
13+
- @inline npartials{N,T}(::Type{Partials{N,T}}) = N
14+
-
15+
- @inline Base.length{N}(::Partials{N}) = N
16+
- @inline Base.size{N}(::Partials{N}) = (N,)
17+
-
18+
- @inline Base.getindex(partials::Partials, i::Int) = partials.values[i]
19+
- setindex{N,T}(partials::Partials{N,T}, v, i) = Partials{N,T}((partials[1:i-1]..., v, partials[i+1:N]...))
20+
-
21+
- Base.start(partials::Partials) = start(partials.values)
22+
- Base.next(partials::Partials, i) = next(partials.values, i)
23+
- Base.done(partials::Partials, i) = done(partials.values, i)
24+
-
25+
- Base.linearindexing(::Partials) = Base.LinearFast()
26+
-
27+
- #####################
28+
- # Generic Functions #
29+
- #####################
30+
-
31+
- @inline iszero(partials::Partials) = iszero_tuple(partials.values)
32+
-
33+
- @inline Base.zero(partials::Partials) = zero(typeof(partials))
34+
- @inline Base.zero{N,T}(::Type{Partials{N,T}}) = Partials{N,T}(zero_tuple(NTuple{N,T}))
35+
-
36+
- @inline Base.one(partials::Partials) = one(typeof(partials))
37+
- @inline Base.one{N,T}(::Type{Partials{N,T}}) = Partials{N,T}(one_tuple(NTuple{N,T}))
38+
-
39+
- @inline Base.rand(partials::Partials) = rand(typeof(partials))
40+
- @inline Base.rand{N,T}(::Type{Partials{N,T}}) = Partials{N,T}(rand_tuple(NTuple{N,T}))
41+
- @inline Base.rand(rng::AbstractRNG, partials::Partials) = rand(rng, typeof(partials))
42+
- @inline Base.rand{N,T}(rng::AbstractRNG, ::Type{Partials{N,T}}) = Partials{N,T}(rand_tuple(rng, NTuple{N,T}))
43+
-
44+
- Base.isequal{N}(a::Partials{N}, b::Partials{N}) = isequal(a.values, b.values)
45+
- @compat(Base.:(==)){N}(a::Partials{N}, b::Partials{N}) = a.values == b.values
46+
-
47+
- const PARTIALS_HASH = hash(Partials)
48+
-
49+
- Base.hash(partials::Partials) = hash(partials.values, PARTIALS_HASH)
50+
- Base.hash(partials::Partials, hsh::UInt64) = hash(hash(partials), hsh)
51+
-
52+
- @inline Base.copy(partials::Partials) = partials
53+
-
54+
- Base.read{N,T}(io::IO, ::Type{Partials{N,T}}) = Partials{N,T}(ntuple(i->read(io, T), Val{N}))
55+
-
56+
- function Base.write(io::IO, partials::Partials)
57+
- for p in partials
58+
- write(io, p)
59+
- end
60+
- end
61+
-
62+
- ########################
63+
- # Conversion/Promotion #
64+
- ########################
65+
-
66+
- Base.promote_rule{N,A,B}(::Type{Partials{N,A}}, ::Type{Partials{N,B}}) = Partials{N,promote_type(A, B)}
67+
-
68+
- Base.convert{N,T}(::Type{Partials{N,T}}, partials::Partials) = Partials{N,T}(partials.values)
69+
- Base.convert{N,T}(::Type{Partials{N,T}}, partials::Partials{N,T}) = partials
70+
-
71+
- ########################
72+
- # Arithmetic Functions #
73+
- ########################
74+
-
75+
- @inline @compat(Base.:+){N}(a::Partials{N}, b::Partials{N}) = Partials(add_tuples(a.values, b.values))
76+
- @inline @compat(Base.:-){N}(a::Partials{N}, b::Partials{N}) = Partials(sub_tuples(a.values, b.values))
77+
- @inline @compat(Base.:-)(partials::Partials) = Partials(minus_tuple(partials.values))
78+
- @inline @compat(Base.:*)(x::Real, partials::Partials) = partials*x
79+
-
80+
- @inline function _div_partials(a::Partials, b::Partials, aval, bval)
81+
- return _mul_partials(a, b, inv(bval), -(aval / (bval*bval)))
82+
- end
83+
-
84+
- # NaN/Inf-safe methods #
85+
- #----------------------#
86+
-
87+
- if NANSAFE_MODE_ENABLED
88+
- @inline function @compat(Base.:*)(partials::Partials, x::Real)
89+
- x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
90+
- return Partials(scale_tuple(partials.values, x))
91+
- end
92+
-
93+
- @inline function @compat(Base.:/)(partials::Partials, x::Real)
94+
- x = ifelse(x == zero(x) && iszero(partials), one(x), x)
95+
- return Partials(div_tuple_by_scalar(partials.values, x))
96+
- end
97+
-
98+
- @inline function _mul_partials{N}(a::Partials{N}, b::Partials{N}, x_a, x_b)
99+
- x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
100+
- x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
101+
- return Partials(mul_tuples(a.values, b.values, x_a, x_b))
102+
- end
103+
- else
104+
- @inline function @compat(Base.:*)(partials::Partials, x::Real)
105+
- return Partials(scale_tuple(partials.values, x))
106+
- end
107+
-
108+
- @inline function @compat(Base.:/)(partials::Partials, x::Real)
109+
- return Partials(div_tuple_by_scalar(partials.values, x))
110+
- end
111+
-
112+
- @inline function _mul_partials{N}(a::Partials{N}, b::Partials{N}, x_a, x_b)
113+
- return Partials(mul_tuples(a.values, b.values, x_a, x_b))
114+
- end
115+
- end
116+
-
117+
- # edge cases where N == 0 #
118+
- #-------------------------#
119+
-
120+
- @inline @compat(Base.:+){A,B}(a::Partials{0,A}, b::Partials{0,B}) = Partials{0,promote_type(A,B)}(tuple())
121+
- @inline @compat(Base.:-){A,B}(a::Partials{0,A}, b::Partials{0,B}) = Partials{0,promote_type(A,B)}(tuple())
122+
- @inline @compat(Base.:-){T}(partials::Partials{0,T}) = partials
123+
- @inline @compat(Base.:*){T}(partials::Partials{0,T}, x::Real) = Partials{0,promote_type(T,typeof(x))}(tuple())
124+
- @inline @compat(Base.:*){T}(x::Real, partials::Partials{0,T}) = Partials{0,promote_type(T,typeof(x))}(tuple())
125+
- @inline @compat(Base.:/){T}(partials::Partials{0,T}, x::Real) = Partials{0,promote_type(T,typeof(x))}(tuple())
126+
-
127+
- @inline _mul_partials{A,B}(a::Partials{0,A}, b::Partials{0,B}, afactor, bfactor) = Partials{0,promote_type(A,B)}(tuple())
128+
- @inline _div_partials{A,B}(a::Partials{0,A}, b::Partials{0,B}, afactor, bfactor) = Partials{0,promote_type(A,B)}(tuple())
129+
-
130+
- ##################################
131+
- # Generated Functions on NTuples #
132+
- ##################################
133+
- # The below functions are generally
134+
- # equivalent to directly mapping over
135+
- # tuples using `map`, but run a bit
136+
- # faster since they generate inline code
137+
- # that doesn't rely on closures.
138+
-
139+
- function tupexpr(f, N)
140+
- ex = Expr(:tuple, [f(i) for i=1:N]...)
141+
4049855 return quote
142+
- $(Expr(:meta, :inline))
143+
- @inbounds return $ex
144+
- end
145+
- end
146+
-
147+
- @inline iszero_tuple(::Tuple{}) = true
148+
- @inline zero_tuple(::Type{Tuple{}}) = tuple()
149+
- @inline one_tuple(::Type{Tuple{}}) = tuple()
150+
- @inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple()
151+
- @inline rand_tuple(::Type{Tuple{}}) = tuple()
152+
-
153+
- @generated function iszero_tuple{N,T}(tup::NTuple{N,T})
154+
- ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...)
155+
- return quote
156+
- z = zero(T)
157+
- $(Expr(:meta, :inline))
158+
- @inbounds return $ex
159+
- end
160+
- end
161+
-
162+
- @generated function zero_tuple{N,T}(::Type{NTuple{N,T}})
163+
- ex = tupexpr(i -> :(z), N)
164+
784 return quote
165+
- z = zero(T)
166+
- return $ex
167+
- end
168+
- end
169+
-
170+
- @generated function one_tuple{N,T}(::Type{NTuple{N,T}})
171+
- ex = tupexpr(i -> :(z), N)
172+
- return quote
173+
- z = one(T)
174+
- return $ex
175+
- end
176+
- end
177+
-
178+
- @generated function rand_tuple{N,T}(rng::AbstractRNG, ::Type{NTuple{N,T}})
179+
- return tupexpr(i -> :(rand(rng, T)), N)
180+
- end
181+
-
182+
- @generated function rand_tuple{N,T}(::Type{NTuple{N,T}})
183+
- return tupexpr(i -> :(rand(T)), N)
184+
- end
185+
-
186+
- @generated function scale_tuple{N}(tup::NTuple{N}, x)
187+
- return tupexpr(i -> :(tup[$i] * x), N)
188+
- end
189+
-
190+
- @generated function div_tuple_by_scalar{N}(tup::NTuple{N}, x)
191+
- return tupexpr(i -> :(tup[$i] / x), N)
192+
- end
193+
-
194+
- @generated function add_tuples{N}(a::NTuple{N}, b::NTuple{N})
195+
- return tupexpr(i -> :(a[$i] + b[$i]), N)
196+
- end
197+
-
198+
- @generated function sub_tuples{N}(a::NTuple{N}, b::NTuple{N})
199+
- return tupexpr(i -> :(a[$i] - b[$i]), N)
200+
- end
201+
-
202+
- @generated function minus_tuple{N}(tup::NTuple{N})
203+
- return tupexpr(i -> :(-tup[$i]), N)
204+
- end
205+
-
206+
- @generated function mul_tuples{N}(a::NTuple{N}, b::NTuple{N}, afactor, bfactor)
207+
- return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
208+
- end
209+
-
210+
- ###################
211+
- # Pretty Printing #
212+
- ###################
213+
-
214+
- Base.show{N}(io::IO, p::Partials{N}) = print(io, "Partials", p.values)
215+
-

test/DualTest.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ samerng() = MersenneTwister(1)
1515
# exponent by one
1616
intrand(T) = T == Int ? rand(2:10) : rand(T)
1717

18+
dualapprox(A, B) = value(A) ≈ value(B) && partials(A) ≈ partials(B)
19+
1820
# fix testing issue with Base.hypot(::Int...) undefined in 0.4
1921
if v"0.4" <= VERSION < v"0.5"
2022
Base.hypot(x::Int, y::Int) = Base.hypot(Float64(x), Float64(y))
@@ -387,20 +389,20 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
387389

388390
@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0
389391

390-
@test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
392+
@test dualapprox(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
391393
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
392-
PARTIALS3)
393-
@test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
394-
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)
395-
@test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396-
PRIMAL*PARTIALS2 + PARTIALS3)
397-
@test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398-
PRIMAL*PARTIALS2)
399-
@test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400-
PRIMAL2*PARTIALS + PARTIALS3)
401-
@test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402-
PRIMAL2*PARTIALS)
403-
@test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)
394+
PARTIALS3))
395+
@test dualapprox(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396+
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS))
397+
@test dualapprox(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398+
PRIMAL*PARTIALS2 + PARTIALS3))
399+
@test dualapprox(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400+
PRIMAL*PARTIALS2))
401+
@test dualapprox(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402+
PRIMAL2*PARTIALS + PARTIALS3))
403+
@test dualapprox(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
404+
PRIMAL2*PARTIALS))
405+
@test dualapprox(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3))
404406

405407
# Unary Functions #
406408
#-----------------#

0 commit comments

Comments
 (0)