Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 149 additions & 12 deletions src/partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,28 +197,165 @@ end
return tupexpr(i -> :(rand(V)), N)
end

@generated function scale_tuple(tup::NTuple{N}, x) where N
return tupexpr(i -> :(tup[$i] * x), N)
const SIMDFloat = Union{Float64, Float32}
const SIMDInt = Union{
Int128, Int64, Int32, Int16, Int8,
UInt128, UInt64, UInt32, UInt16, UInt8,
Bool
}
const SIMDType = Union{SIMDFloat, SIMDInt}

# This may not be a sharp bound, but at least people won't get worse result.
const HAS_FLEXIABLE_VECTOR_LENGTH = VERSION >= v"1.6"

function julia_type_to_llvm_type(@nospecialize(T::DataType))
T === Float64 ? "double" :
T === Float32 ? "float" :
T <: Union{Int128,UInt128} ? "i128" :
T <: Union{Int64,UInt64} ? "i64" :
T <: Union{Int32,UInt32} ? "i32" :
T <: Union{Int16,UInt16} ? "i16" :
T <: Union{Bool,Int8,UInt8} ? "i8" :
error("$T cannot be mapped to a LLVM type")
end

@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
return tupexpr(i -> :(tup[$i] / x), N)
@generated function scale_tuple(tup::NTuple{N,T}, x::S) where {N,T,S}
if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
return tupexpr(i -> :(tup[$i] * x), N)
end

S = julia_type_to_llvm_type(T)
VT = NTuple{N, VecElement{T}}
op = T <: SIMDFloat ? "fmul nsz contract" : "mul"
llvmir = """
%el = insertelement <$N x $S> undef, $S %1, i32 0
%vx = shufflevector <$N x $S> %el, <$N x $S> undef, <$N x i32> zeroinitializer
%res = $op <$N x $S> %0, %vx
ret <$N x $S> %res
"""

quote
$(Expr(:meta, :inline))
ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
Base.@ntuple $N i->ret[i].value
end
end

@generated function div_tuple_by_scalar(tup::NTuple{N,T}, x::S) where {N,T,S}
if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S === typeof(one(T) / one(S)) && S <: SIMDType)
return tupexpr(i -> :(tup[$i] / x), N)
end

S = julia_type_to_llvm_type(T)
VT = NTuple{N, VecElement{T}}
op = T <: SIMDFloat ? "fdiv nsz contract" : "div"
llvmir = """
%el = insertelement <$N x $S> undef, $S %1, i32 0
%vx = shufflevector <$N x $S> %el, <$N x $S> undef, <$N x i32> zeroinitializer
%res = $op <$N x $S> %0, %vx
ret <$N x $S> %res
"""

quote
$(Expr(:meta, :inline))
ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
Base.@ntuple $N i->ret[i].value
end
end

@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
return tupexpr(i -> :(a[$i] + b[$i]), N)
@generated function add_tuples(a::NTuple{N,T}, b::NTuple{N,S}) where {N,T,S}
if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
return tupexpr(i -> :(a[$i] + b[$i]), N)
end

S = julia_type_to_llvm_type(T)
VT = NTuple{N, VecElement{T}}
op = T <: SIMDFloat ? "fadd nsz contract" : "add"
llvmir = """
%res = $op <$N x $S> %0, %1
ret <$N x $S> %res
"""

quote
$(Expr(:meta, :inline))
ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $VT}, $VT(a), $VT(b))
Base.@ntuple $N i->ret[i].value
end
end

@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
return tupexpr(i -> :(a[$i] - b[$i]), N)
@generated function sub_tuples(a::NTuple{N,T}, b::NTuple{N,S}) where {N,T,S}
if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
return tupexpr(i -> :(a[$i] - b[$i]), N)
end

S = julia_type_to_llvm_type(T)
VT = NTuple{N, VecElement{T}}
op = T <: SIMDFloat ? "fsub nsz contract" : "sub"
llvmir = """
%res = $op <$N x $S> %0, %1
ret <$N x $S> %res
"""

quote
$(Expr(:meta, :inline))
ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $VT}, $VT(a), $VT(b))
Base.@ntuple $N i->ret[i].value
end
end

@generated function minus_tuple(tup::NTuple{N}) where N
return tupexpr(i -> :(-tup[$i]), N)
@generated function minus_tuple(tup::NTuple{N,T}) where {N,T}
(HAS_FLEXIABLE_VECTOR_LENGTH && T <: SIMDType) || return tupexpr(i -> :(-tup[$i]), N)

S = julia_type_to_llvm_type(T)
VT = NTuple{N, VecElement{T}}
if T <: SIMDFloat
llvmir = """
%res = fneg nsz contract <$N x $S> %0
ret <$N x $S> %res
"""
else
llvmir = """
%res = sub <$N x $S> zeroinitializer, %0
ret <$N x $S> %res
"""
end

quote
$(Expr(:meta, :inline))
ret = Base.llvmcall($llvmir, $VT, Tuple{$VT}, $VT(tup))
Base.@ntuple $N i->ret[i].value
end
end

@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
@generated function mul_tuples(a::NTuple{N,V1}, b::NTuple{N,V2}, afactor::S1, bfactor::S2) where {N,V1,V2,S1,S2}
if !(HAS_FLEXIABLE_VECTOR_LENGTH && V1 === V2 === S1 === S2 && S2 <: SIMDFloat)
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
end

T = V1
S = julia_type_to_llvm_type(T)
fmuladd = "@llvm.fmuladd.v$(N)f$(sizeof(T)*8)"

VT = NTuple{N, VecElement{T}}
llvmir = """
declare <$N x $S> $fmuladd(<$N x $S>, <$N x $S>, <$N x $S>)

define <$N x $S> @entry(<$N x $S>, <$N x $S>, $S, $S) alwaysinline {
top:
%el1 = insertelement <$N x $S> undef, $S %2, i32 0
%afactor = shufflevector <$N x $S> %el1, <$N x $S> undef, <$N x i32> zeroinitializer
%el2 = insertelement <$N x $S> undef, $S %3, i32 0
%bfactor = shufflevector <$N x $S> %el2, <$N x $S> undef, <$N x i32> zeroinitializer
%tmp = fmul nsz contract <$N x $S> %1, %bfactor
%res = call nsz contract <$N x $S> $fmuladd(<$N x $S> %0, <$N x $S> %afactor, <$N x $S> %tmp)
ret <$N x $S> %res
}
"""
quote
$(Expr(:meta, :inline))
ret = Base.llvmcall(($llvmir, "entry"), $VT, Tuple{$VT, $VT, $T, $T}, $VT(a), $VT(b), afactor, bfactor)
Base.@ntuple $N i->ret[i].value
end
end

###################
Expand Down
7 changes: 6 additions & 1 deletion test/PartialsTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ using ForwardDiff: Partials

samerng() = MersenneTwister(1)

approx_tuple(x, y) = all(zip(x, y)) do (a, b)
a ≈ b
end

for N in (0, 3), T in (Int, Float32, Float64)
println(" ...testing Partials{$N,$T}")

Expand Down Expand Up @@ -114,7 +118,8 @@ for N in (0, 3), T in (Int, Float32, Float64)

if N > 0
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
# FMA
@test approx_tuple(ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values, map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2))
@test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
@test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS

Expand Down