Skip to content

Commit f72be8d

Browse files
KristofferC and YingboMa committed
use SIMD.jl for explicit vectorization of partial operations
Alternative to #555 Co-authored-by: Yingbo Ma <[email protected]>
1 parent 0af523a commit f72be8d

File tree

4 files changed

+347
-25
lines changed

4 files changed

+347
-25
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1212
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1313
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1516
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1617
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718

@@ -24,9 +25,10 @@ DiffTests = "0.0.1, 0.1"
2425
LogExpFunctions = "0.3"
2526
NaNMath = "0.2.2, 0.3"
2627
Preferences = "1"
28+
SIMD = "3"
2729
SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0"
2830
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
29-
julia = "1"
31+
julia = "1.6"
3032

3133
[extras]
3234
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"

src/partials.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -197,29 +197,35 @@ end
197197
return tupexpr(i -> :(rand(V)), N)
198198
end
199199

200-
@generated function scale_tuple(tup::NTuple{N}, x) where N
201-
return tupexpr(i -> :(tup[$i] * x), N)
202-
end
203-
204-
@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
205-
return tupexpr(i -> :(tup[$i] / x), N)
206-
end
207-
208-
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
209-
return tupexpr(i -> :(a[$i] + b[$i]), N)
210-
end
211-
212-
@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
213-
return tupexpr(i -> :(a[$i] - b[$i]), N)
214-
end
215-
216-
@generated function minus_tuple(tup::NTuple{N}) where N
217-
return tupexpr(i -> :(-tup[$i]), N)
218-
end
219-
220-
@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
221-
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
222-
end
200+
# LLVM versions before roughly version 6 had problems with certain vector lengths,
# so explicit vectorization is only considered available from Julia 1.6 onwards.
const HAS_FLEXIBLE_VECTOR_LENGTH = VERSION >= v"1.6"

# Element types eligible for the explicit SIMD code paths.
const SIMDFloat = Union{Float32, Float64}
const SIMDInt = Union{
    Int8, Int16, Int32, Int64, Int128,
    UInt8, UInt16, UInt32, UInt64, UInt128,
}
const SIMDType = Union{SIMDFloat, SIMDInt}
209+
210+
using SIMD
211+
212+
# Short alias for `NTuple{N,T}`, used in the tuple-helper method signatures below.
const NT{N,T} = NTuple{N,T}
213+
214+
# SIMD implementations: selected when all tuple elements (and any scalar factors)
# share a single element type `T` from `SIMDType`, so the tuples can be wrapped in
# `Vec` and processed with vector operations.
function add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType}
    return Tuple(Vec(a) + Vec(b))
end

function sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType}
    return Tuple(Vec(a) - Vec(b))
end

function scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType}
    return Tuple(Vec(tup) * x)
end

# Division is only vectorized for floating-point element types.
function div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat}
    return Tuple(Vec(tup) / x)
end

function minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType}
    return Tuple(-Vec(tup))
end

# Computes `af * a + bf * b` element-wise, with both factors broadcast to vectors.
function mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType}
    a_factor = Vec{N,T}(af)
    a_vec = Vec(a)
    b_scaled = Vec{N,T}(bf) * Vec(b)
    return Tuple(muladd(a_factor, a_vec, b_scaled))
end
221+
222+
# Fallback implementations: element-wise expressions unrolled at compile time via
# `tupexpr`, used for argument types not matched by a more specific method.
@generated function add_tuples(a::NT{N}, b::NT{N}) where N
    return tupexpr(k -> :(a[$k] + b[$k]), N)
end

@generated function sub_tuples(a::NT{N}, b::NT{N}) where N
    return tupexpr(k -> :(a[$k] - b[$k]), N)
end

@generated function scale_tuple(tup::NT{N}, x) where N
    return tupexpr(k -> :(tup[$k] * x), N)
end

@generated function div_tuple_by_scalar(tup::NT{N}, x) where N
    return tupexpr(k -> :(tup[$k] / x), N)
end

@generated function minus_tuple(tup::NT{N}) where N
    return tupexpr(k -> :(-tup[$k]), N)
end

@generated function mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N
    return tupexpr(k -> :((af * a[$k]) + (bf * b[$k])), N)
end
223229

224230
###################
225231
# Pretty Printing #

src/partials2.jl

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# Immutable vector of `N` elements of type `V`, backed by a tuple.
# Subtypes `AbstractVector{V}`.
struct Partials{N,V} <: AbstractVector{V}
2+
# The `N` stored elements.
values::NTuple{N,V}
3+
end
4+
5+
##############################
6+
# Utility/Accessor Functions #
7+
##############################
8+
9+
# Build a `Partials` whose `i`-th entry is `one(V)` and whose other `N - 1` entries
# are `zero(V)`. The tuple literal is assembled at generation time.
@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
    slots = [j === i ? :(one(V)) : :(zero(V)) for j in 1:N]
    seed = Expr(:tuple, slots...)
    return :(Partials($seed))
end
13+
14+
# Element type `V`, queried from an instance or from the type itself.
@inline function valtype(::Partials{N,V}) where {N,V}
    return V
end
@inline function valtype(::Type{Partials{N,V}}) where {N,V}
    return V
end

# Number of stored elements `N`, queried from an instance or from the type itself.
@inline function npartials(::Partials{N}) where {N}
    return N
end
@inline function npartials(::Type{Partials{N,V}}) where {N,V}
    return N
end

# Array-interface size queries; both are determined by the type parameter `N`.
@inline function Base.length(::Partials{N}) where {N}
    return N
end
@inline function Base.size(::Partials{N}) where {N}
    return (N,)
end
22+
23+
# Indexing forwards to the underlying tuple.
@inline Base.@propagate_inbounds function Base.getindex(partials::Partials, i::Int)
    return partials.values[i]
end

# Iteration also forwards to the underlying tuple.
function Base.iterate(partials::Partials)
    return iterate(partials.values)
end
function Base.iterate(partials::Partials, i)
    return iterate(partials.values, i)
end

function Base.IndexStyle(::Type{<:Partials})
    return IndexLinear()
end

# Can be deleted once https://github.com/JuliaLang/julia/pull/29854 is in a release
function Base.mightalias(x::AbstractArray, y::Partials)
    return false
end
32+
33+
#####################
34+
# Generic Functions #
35+
#####################
36+
37+
# True when every stored element compares equal to zero.
@inline function Base.iszero(partials::Partials)
    return iszero_tuple(partials.values)
end

# `zero`/`one`: the instance methods forward to the type methods, which fill a
# tuple of the right length and element type.
@inline Base.zero(::Partials{N,V}) where {N,V} = zero(Partials{N,V})
@inline function Base.zero(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(zero_tuple(NTuple{N,V}))
end

@inline Base.one(::Partials{N,V}) where {N,V} = one(Partials{N,V})
@inline function Base.one(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(one_tuple(NTuple{N,V}))
end

# `rand`: with or without an explicit RNG; instance methods forward to the type.
@inline Random.rand(::Partials{N,V}) where {N,V} = rand(Partials{N,V})
@inline function Random.rand(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(rand_tuple(NTuple{N,V}))
end
@inline function Random.rand(rng::AbstractRNG, ::Partials{N,V}) where {N,V}
    return rand(rng, Partials{N,V})
end
@inline function Random.rand(rng::AbstractRNG, ::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(rand_tuple(rng, NTuple{N,V}))
end
49+
50+
# Equality between two `Partials` of the same length compares the stored tuples.
function Base.isequal(a::Partials{N}, b::Partials{N}) where {N}
    return isequal(a.values, b.values)
end
function Base.:(==)(a::Partials{N}, b::Partials{N}) where {N}
    return a.values == b.values
end
52+
53+
# Type-specific seed mixed into every `Partials` hash.
const PARTIALS_HASH = hash(Partials)

# One-argument form: hash the stored tuple, seeded with `PARTIALS_HASH`.
Base.hash(partials::Partials) = hash(partials.values, PARTIALS_HASH)

# Two-argument form: mix the one-argument hash into the incoming seed `hsh`.
# The seed is typed as `UInt` (the platform word size) rather than `UInt64`, so
# this method also applies on 32-bit platforms, where the seed is a `UInt32`.
# On 64-bit platforms `UInt === UInt64`, so behaviour there is unchanged.
Base.hash(partials::Partials, hsh::UInt) = hash(hash(partials), hsh)
57+
58+
# `copy` returns the same object rather than a new one.
@inline function Base.copy(partials::Partials)
    return partials
end

# Read `N` consecutive values of type `V` from `io` and wrap them in a `Partials`.
function Base.read(io::IO, ::Type{Partials{N,V}}) where {N,V}
    elements = ntuple(_ -> read(io, V), N)
    return Partials{N,V}(elements)
end
61+
62+
# Write every stored element of `partials` to `io`, in order.
#
# Returns the total number of bytes written, following the `Base.write` contract.
# The previous version returned `nothing`.
function Base.write(io::IO, partials::Partials)
    nbytes = 0
    for p in partials.values
        nbytes += write(io, p)
    end
    return nbytes
end
67+
68+
########################
69+
# Conversion/Promotion #
70+
########################
71+
72+
# Two `Partials` of equal length promote to the promoted element type.
function Base.promote_rule(::Type{Partials{N,A}}, ::Type{Partials{N,B}}) where {N,A,B}
    return Partials{N,promote_type(A, B)}
end

# Conversion rebuilds from the stored tuple; when the target type already matches,
# the input is returned as-is.
function Base.convert(::Type{Partials{N,V}}, partials::Partials) where {N,V}
    return Partials{N,V}(partials.values)
end
function Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V}
    return partials
end
76+
77+
########################
78+
# Arithmetic Functions #
79+
########################
80+
81+
# Element-wise sum, difference and negation, delegated to the tuple helpers.
@inline function Base.:+(a::Partials{N}, b::Partials{N}) where {N}
    return Partials(add_tuples(a.values, b.values))
end
@inline function Base.:-(a::Partials{N}, b::Partials{N}) where {N}
    return Partials(sub_tuples(a.values, b.values))
end
@inline function Base.:-(partials::Partials)
    return Partials(minus_tuple(partials.values))
end

# Scalar-on-the-left multiplication forwards to the `partials * x` method.
@inline function Base.:*(x::Real, partials::Partials)
    return partials * x
end

# Combine `a` and `b` through `_mul_partials`, using the factors
# `inv(bval)` for `a` and `-(aval / bval^2)` for `b`.
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
    a_factor = inv(bval)
    b_factor = -(aval / (bval*bval))
    return _mul_partials(a, b, a_factor, b_factor)
end
89+
90+
# NaN/Inf-safe methods #
91+
#----------------------#
92+
93+
if NANSAFE_MODE_ENABLED
    # NaN-safe variants: when the partials are all zero and the scalar would
    # otherwise be problematic (non-finite factor, or zero divisor), the scalar is
    # replaced by `one(...)` before the tuple operation is applied.
    @inline function Base.:*(partials::Partials, x::Real)
        factor = ifelse(!isfinite(x) && iszero(partials), one(x), x)
        return Partials(scale_tuple(partials.values, factor))
    end

    @inline function Base.:/(partials::Partials, x::Real)
        divisor = ifelse(x == zero(x) && iszero(partials), one(x), x)
        return Partials(div_tuple_by_scalar(partials.values, divisor))
    end

    @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
        a_factor = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
        b_factor = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
        return Partials(mul_tuples(a.values, b.values, a_factor, b_factor))
    end
else
    # Plain variants: the scalars are passed to the tuple helpers unchanged.
    @inline Base.:*(partials::Partials, x::Real) =
        Partials(scale_tuple(partials.values, x))

    @inline Base.:/(partials::Partials, x::Real) =
        Partials(div_tuple_by_scalar(partials.values, x))

    @inline _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N =
        Partials(mul_tuples(a.values, b.values, x_a, x_b))
end
122+
123+
# edge cases where N == 0 #
124+
#-------------------------#
125+
126+
# Addition: an empty operand contributes nothing; the result uses the promoted
# element type.
@inline function Base.:+(a::Partials{0,A}, b::Partials{0,B}) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function Base.:+(a::Partials{0,A}, b::Partials{N,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, b)
end
@inline function Base.:+(a::Partials{N,A}, b::Partials{0,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, a)
end

# Subtraction: same idea, with the non-empty right operand negated.
@inline function Base.:-(a::Partials{0,A}, b::Partials{0,B}) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function Base.:-(a::Partials{0,A}, b::Partials{N,B}) where {N,A,B}
    return -(convert(Partials{N,promote_type(A,B)}, b))
end
@inline function Base.:-(a::Partials{N,A}, b::Partials{0,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, a)
end
@inline function Base.:-(partials::Partials{0,V}) where {V}
    return partials
end

# Scaling or dividing an empty `Partials` yields an empty `Partials` of the
# promoted element type.
@inline function Base.:*(partials::Partials{0,V}, x::Real) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end
@inline function Base.:*(x::Real, partials::Partials{0,V}) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end

@inline function Base.:/(partials::Partials{0,V}, x::Real) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end

# `_mul_partials` with an empty operand reduces to scaling the other one.
@inline function _mul_partials(a::Partials{0,A}, b::Partials{0,B}, afactor, bfactor) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B}
    return bfactor * b
end
@inline function _mul_partials(a::Partials{N,A}, b::Partials{0,B}, afactor, bfactor) where {N,A,B}
    return afactor * a
end
143+
144+
##################################
145+
# Generated Functions on NTuples #
146+
##################################
147+
# The below functions are generally
148+
# equivalent to directly mapping over
149+
# tuples using `map`, but run a bit
150+
# faster since they generate inline code
151+
# that doesn't rely on closures.
152+
153+
# Build the body of a generated function that returns an `N`-element tuple.
#
# `f(i)` must return the expression for the `i`-th element. The resulting block
# is marked for inlining and evaluates the tuple under `@inbounds`.
function tupexpr(f, N)
    elements = map(f, 1:N)
    tuple_ex = Expr(:tuple, elements...)
    return quote
        $(Expr(:meta, :inline))
        @inbounds return $tuple_ex
    end
end
160+
161+
# Methods for the empty tuple: trivially zero, and every constructor returns `()`.
@inline function iszero_tuple(::Tuple{})
    return true
end
@inline function zero_tuple(::Type{Tuple{}})
    return ()
end
@inline function one_tuple(::Type{Tuple{}})
    return ()
end
@inline function rand_tuple(::AbstractRNG, ::Type{Tuple{}})
    return ()
end
@inline function rand_tuple(::Type{Tuple{}})
    return ()
end
166+
167+
# Return `true` when every element of `tup` compares equal to `zero(V)`.
# The comparison chain is unrolled at generation time and short-circuits.
@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V}
    checks = [:(z == tup[$k]) for k in 1:N]
    all_zero = Expr(:&&, checks...)
    return quote
        z = zero(V)
        $(Expr(:meta, :inline))
        @inbounds return $all_zero
    end
end
175+
176+
# Return an `N`-tuple in which every element is `zero(V)`; the zero is computed
# once and reused for each slot.
@generated function zero_tuple(::Type{NTuple{N,V}}) where {N,V}
    filled = tupexpr(_ -> :z, N)
    return quote
        z = zero(V)
        return $filled
    end
end
183+
184+
# Return an `N`-tuple in which every element is `one(V)`; the value is computed
# once and reused for each slot.
@generated function one_tuple(::Type{NTuple{N,V}}) where {N,V}
    filled = tupexpr(_ -> :z, N)
    return quote
        z = one(V)
        return $filled
    end
end
191+
192+
# Return an `N`-tuple of independent `rand(rng, V)` draws.
@generated function rand_tuple(rng::AbstractRNG, ::Type{NTuple{N,V}}) where {N,V}
    draw = :(rand(rng, V))
    return tupexpr(_ -> draw, N)
end

# Same as above, but drawing with `rand(V)` (no explicit RNG argument).
@generated function rand_tuple(::Type{NTuple{N,V}}) where {N,V}
    draw = :(rand(V))
    return tupexpr(_ -> draw, N)
end
199+
200+
# Element types eligible for the hand-written LLVM vector code paths.
const SIMDFloat = Union{Float32, Float64}
const SIMDInt = Union{
    Bool,
    Int8, Int16, Int32, Int64, Int128,
    UInt8, UInt16, UInt32, UInt64, UInt128,
}
const SIMDType = Union{SIMDFloat, SIMDInt}

# This may not be a sharp bound, but at least people won't get worse results.
const HAS_FLEXIABLE_VECTOR_LENGTH = VERSION >= v"1.6"
210+
211+
# Map a Julia primitive numeric type to the name of the LLVM scalar type used for
# it in the IR snippets below. Signed and unsigned integers of the same width share
# one LLVM type; `Bool` maps to `i8`. Any other type raises an error.
function julia_type_to_llvm_type(@nospecialize(T::DataType))
    if T === Float64
        return "double"
    elseif T === Float32
        return "float"
    elseif T <: Union{Int128,UInt128}
        return "i128"
    elseif T <: Union{Int64,UInt64}
        return "i64"
    elseif T <: Union{Int32,UInt32}
        return "i32"
    elseif T <: Union{Int16,UInt16}
        return "i16"
    elseif T <: Union{Bool,Int8,UInt8}
        return "i8"
    end
    return error("$T cannot be mapped to a LLVM type")
end
221+
222+
# Multiply every element of `tup` by the scalar `x`.
#
# When the element type and the scalar type are the same `SIMDType` (and flexible
# vector lengths are available), the product is emitted as a single LLVM vector
# multiply; otherwise the element-wise products are unrolled with `tupexpr`.
@generated function scale_tuple(tup::NTuple{N,T}, x::S) where {N,T,S}
    vectorize = HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType
    vectorize || return tupexpr(k -> :(tup[$k] * x), N)

    # From here on `S` holds the LLVM type name rather than the Julia scalar type.
    S = julia_type_to_llvm_type(T)
    vec_type = NTuple{N, VecElement{T}}
    op = T <: SIMDFloat ? "fmul nsz contract" : "mul"
    # Splat the scalar (%1) across a vector, then multiply it with the input (%0).
    ir = """
    %el = insertelement <$N x $S> undef, $S %1, i32 0
    %vx = shufflevector <$N x $S> %el, <$N x $S> undef, <$N x i32> zeroinitializer
    %res = $op <$N x $S> %0, %vx
    ret <$N x $S> %res
    """

    return quote
        $(Expr(:meta, :inline))
        packed = Base.llvmcall($ir, $vec_type, Tuple{$vec_type, $T}, $vec_type(tup), x)
        Base.@ntuple $N i->packed[i].value
    end
end
243+
244+
# Divide every element of `tup` by the scalar `x`.
#
# The vector path is taken only when the element type and the scalar type are the
# same floating-point `SIMDFloat` type. Those are the only types for which `T / T`
# is again `T`, so this is the same set of types the previous guard admitted. The
# guard no longer evaluates `one(T) / one(S)` at generation time, and the
# unreachable integer branch is gone (it named `div`, which is not an LLVM
# instruction). All other argument types use the unrolled element-wise fallback.
@generated function div_tuple_by_scalar(tup::NTuple{N,T}, x::S) where {N,T,S}
    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && T <: SIMDFloat)
        return tupexpr(i -> :(tup[$i] / x), N)
    end

    llvm_t = julia_type_to_llvm_type(T)
    VT = NTuple{N, VecElement{T}}
    # Splat the scalar (%1) across a vector, then divide the input (%0) by it.
    llvmir = """
    %el = insertelement <$N x $llvm_t> undef, $llvm_t %1, i32 0
    %vx = shufflevector <$N x $llvm_t> %el, <$N x $llvm_t> undef, <$N x i32> zeroinitializer
    %res = fdiv nsz contract <$N x $llvm_t> %0, %vx
    ret <$N x $llvm_t> %res
    """

    return quote
        $(Expr(:meta, :inline))
        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
        Base.@ntuple $N i->ret[i].value
    end
end
265+
266+
# Element-wise `a + b`, unrolled at compile time.
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
    elementwise = k -> :(a[$k] + b[$k])
    return tupexpr(elementwise, N)
end

# Element-wise negation, unrolled at compile time.
@generated function minus_tuple(tup::NTuple{N}) where N
    elementwise = k -> :(-tup[$k])
    return tupexpr(elementwise, N)
end

# Element-wise `a - b`, unrolled at compile time.
@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
    elementwise = k -> :(a[$k] - b[$k])
    return tupexpr(elementwise, N)
end
277+
278+
279+
# Compute `afactor * a + bfactor * b` element-wise.
#
# When both tuples and both factors share one `SIMDFloat` type (and flexible vector
# lengths are available), the computation is emitted as LLVM vector code: `b` is
# multiplied by the splatted `bfactor`, and the result is combined with `a` and the
# splatted `afactor` through the `llvm.fmuladd` intrinsic. Otherwise the
# expression is unrolled element by element with `tupexpr`.
@generated function mul_tuples(
    a::NTuple{N,V1}, b::NTuple{N,V2}, afactor::S1, bfactor::S2
) where {N,V1,V2,S1,S2}
    vectorize = HAS_FLEXIABLE_VECTOR_LENGTH && V1 === V2 === S1 === S2 && S2 <: SIMDFloat
    vectorize || return tupexpr(k -> :((afactor * a[$k]) + (bfactor * b[$k])), N)

    T = V1
    S = julia_type_to_llvm_type(T)
    # Intrinsic name encodes the vector length and the element bit width.
    fmuladd = "@llvm.fmuladd.v$(N)f$(sizeof(T)*8)"

    vec_type = NTuple{N, VecElement{T}}
    ir = """
    declare <$N x $S> $fmuladd(<$N x $S>, <$N x $S>, <$N x $S>)

    define <$N x $S> @entry(<$N x $S>, <$N x $S>, $S, $S) alwaysinline {
    top:
    %el1 = insertelement <$N x $S> undef, $S %2, i32 0
    %afactor = shufflevector <$N x $S> %el1, <$N x $S> undef, <$N x i32> zeroinitializer
    %el2 = insertelement <$N x $S> undef, $S %3, i32 0
    %bfactor = shufflevector <$N x $S> %el2, <$N x $S> undef, <$N x i32> zeroinitializer
    %tmp = fmul nsz contract <$N x $S> %1, %bfactor
    %res = call nsz contract <$N x $S> $fmuladd(<$N x $S> %0, <$N x $S> %afactor, <$N x $S> %tmp)
    ret <$N x $S> %res
    }
    """
    return quote
        $(Expr(:meta, :inline))
        packed = Base.llvmcall(
            ($ir, "entry"),
            $vec_type,
            Tuple{$vec_type, $vec_type, $T, $T},
            $vec_type(a), $vec_type(b), afactor, bfactor,
        )
        Base.@ntuple $N i->packed[i].value
    end
end
309+
310+
###################
311+
# Pretty Printing #
312+
###################
313+
314+
# Print as the literal text `Partials` followed by the stored tuple.
function Base.show(io::IO, p::Partials{N}) where {N}
    return print(io, "Partials", p.values)
end

0 commit comments

Comments
 (0)