diff --git a/src/partials.jl b/src/partials.jl
index fce67b0a..0325ea75 100644
--- a/src/partials.jl
+++ b/src/partials.jl
@@ -197,28 +197,184 @@ end
     return tupexpr(i -> :(rand(V)), N)
 end
 
-@generated function scale_tuple(tup::NTuple{N}, x) where N
-    return tupexpr(i -> :(tup[$i] * x), N)
+const SIMDFloat = Union{Float64, Float32}
+# `Bool` is deliberately not a SIMD element type: `+`, `-` and unary `-` on `Bool`
+# promote to `Int` in Julia, which a same-width LLVM vector op cannot reproduce.
+const SIMDInt = Union{
+                       Int128, Int64, Int32, Int16, Int8,
+                       UInt128, UInt64, UInt32, UInt16, UInt8,
+                     }
+const SIMDType = Union{SIMDFloat, SIMDInt}
+
+# Gate for the hand-written LLVM vector code below. This may not be a sharp bound,
+# but when it is false every function keeps its scalar `tupexpr` fallback.
+const HAS_FLEXIABLE_VECTOR_LENGTH = VERSION >= v"1.6"
+
+# Return the LLVM IR type name for the Julia type `T`; throws for unmapped types.
+function julia_type_to_llvm_type(@nospecialize(T::DataType))
+    T === Float64 ? "double" :
+    T === Float32 ? "float" :
+    T <: Union{Int128,UInt128} ? "i128" :
+    T <: Union{Int64,UInt64} ? "i64" :
+    T <: Union{Int32,UInt32} ? "i32" :
+    T <: Union{Int16,UInt16} ? "i16" :
+    T <: Union{Bool,Int8,UInt8} ? "i8" :
+    error("$T cannot be mapped to a LLVM type")
 end
 
-@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
-    return tupexpr(i -> :(tup[$i] / x), N)
+# Multiply every element of `tup` by `x`. Same-typed SIMD elements use a single LLVM
+# vector multiply; every other combination keeps the elementwise `tupexpr` fallback.
+@generated function scale_tuple(tup::NTuple{N,T}, x::S) where {N,T,S}
+    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
+        return tupexpr(i -> :(tup[$i] * x), N)
+    end
+
+    llvmT = julia_type_to_llvm_type(T)
+    VT = NTuple{N, VecElement{T}}
+    op = T <: SIMDFloat ? "fmul nsz contract" : "mul"
+    # Splat the scalar `%1` across all lanes, then combine it with the vector `%0`.
+    llvmir = """
+    %el = insertelement <$N x $llvmT> undef, $llvmT %1, i32 0
+    %vx = shufflevector <$N x $llvmT> %el, <$N x $llvmT> undef, <$N x i32> zeroinitializer
+    %res = $op <$N x $llvmT> %0, %vx
+    ret <$N x $llvmT> %res
+    """
+
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
+        Base.@ntuple $N i->ret[i].value
+    end
+end
+
+# Divide every element of `tup` by `x`. Only same-typed floats take the LLVM vector
+# path: integer `/` returns a float in Julia, so integers keep the `tupexpr` fallback.
+@generated function div_tuple_by_scalar(tup::NTuple{N,T}, x::S) where {N,T,S}
+    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDFloat)
+        return tupexpr(i -> :(tup[$i] / x), N)
+    end
+
+    llvmT = julia_type_to_llvm_type(T)
+    VT = NTuple{N, VecElement{T}}
+    # Splat the scalar `%1` across all lanes, then divide the vector `%0` by it.
+    llvmir = """
+    %el = insertelement <$N x $llvmT> undef, $llvmT %1, i32 0
+    %vx = shufflevector <$N x $llvmT> %el, <$N x $llvmT> undef, <$N x i32> zeroinitializer
+    %res = fdiv nsz contract <$N x $llvmT> %0, %vx
+    ret <$N x $llvmT> %res
+    """
+
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
+        Base.@ntuple $N i->ret[i].value
+    end
 end
 
-@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
-    return tupexpr(i -> :(a[$i] + b[$i]), N)
+# Add `a` and `b` elementwise. Same-typed SIMD elements use one LLVM vector add;
+# every other combination keeps the elementwise `tupexpr` fallback.
+@generated function add_tuples(a::NTuple{N,T}, b::NTuple{N,S}) where {N,T,S}
+    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
+        return tupexpr(i -> :(a[$i] + b[$i]), N)
+    end
+
+    llvmT = julia_type_to_llvm_type(T)
+    VT = NTuple{N, VecElement{T}}
+    op = T <: SIMDFloat ? "fadd nsz contract" : "add"
+    llvmir = """
+    %res = $op <$N x $llvmT> %0, %1
+    ret <$N x $llvmT> %res
+    """
+
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $VT}, $VT(a), $VT(b))
+        Base.@ntuple $N i->ret[i].value
+    end
 end
 
-@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
-    return tupexpr(i -> :(a[$i] - b[$i]), N)
+# Subtract `b` from `a` elementwise. Same-typed SIMD elements use one LLVM vector
+# subtract; every other combination keeps the elementwise `tupexpr` fallback.
+@generated function sub_tuples(a::NTuple{N,T}, b::NTuple{N,S}) where {N,T,S}
+    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType)
+        return tupexpr(i -> :(a[$i] - b[$i]), N)
+    end
+
+    llvmT = julia_type_to_llvm_type(T)
+    VT = NTuple{N, VecElement{T}}
+    op = T <: SIMDFloat ? "fsub nsz contract" : "sub"
+    llvmir = """
+    %res = $op <$N x $llvmT> %0, %1
+    ret <$N x $llvmT> %res
+    """
+
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $VT}, $VT(a), $VT(b))
+        Base.@ntuple $N i->ret[i].value
+    end
 end
 
-@generated function minus_tuple(tup::NTuple{N}) where N
-    return tupexpr(i -> :(-tup[$i]), N)
+# Negate every element of `tup`, using one LLVM vector instruction for SIMD element
+# types and the elementwise `tupexpr` fallback otherwise.
+@generated function minus_tuple(tup::NTuple{N,T}) where {N,T}
+    (HAS_FLEXIABLE_VECTOR_LENGTH && T <: SIMDType) || return tupexpr(i -> :(-tup[$i]), N)
+
+    llvmT = julia_type_to_llvm_type(T)
+    VT = NTuple{N, VecElement{T}}
+    if T <: SIMDFloat
+        llvmir = """
+        %res = fneg nsz contract <$N x $llvmT> %0
+        ret <$N x $llvmT> %res
+        """
+    else
+        # Integers are negated as `0 - x`.
+        llvmir = """
+        %res = sub <$N x $llvmT> zeroinitializer, %0
+        ret <$N x $llvmT> %res
+        """
+    end
+
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT}, $VT(tup))
+        Base.@ntuple $N i->ret[i].value
+    end
 end
 
-@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
-    return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
+# Compute `afactor * a[i] + bfactor * b[i]` for every `i`. When all four element and
+# factor types are the same SIMD float, this is one vector multiply feeding an
+# `llvm.fmuladd`; results can differ from the fallback by floating-point rounding.
+@generated function mul_tuples(a::NTuple{N,V1}, b::NTuple{N,V2}, afactor::S1, bfactor::S2) where {N,V1,V2,S1,S2}
+    if !(HAS_FLEXIABLE_VECTOR_LENGTH && V1 === V2 === S1 === S2 && S2 <: SIMDFloat)
+        return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
+    end
+
+    T = V1
+    llvmT = julia_type_to_llvm_type(T)
+    # The intrinsic name encodes the vector width and element bit size, e.g. `v3f64`.
+    fmuladd = "@llvm.fmuladd.v$(N)f$(sizeof(T)*8)"
+
+    VT = NTuple{N, VecElement{T}}
+    llvmir = """
+    declare <$N x $llvmT> $fmuladd(<$N x $llvmT>, <$N x $llvmT>, <$N x $llvmT>)
+
+    define <$N x $llvmT> @entry(<$N x $llvmT>, <$N x $llvmT>, $llvmT, $llvmT) alwaysinline {
+    top:
+        %el1 = insertelement <$N x $llvmT> undef, $llvmT %2, i32 0
+        %afactor = shufflevector <$N x $llvmT> %el1, <$N x $llvmT> undef, <$N x i32> zeroinitializer
+        %el2 = insertelement <$N x $llvmT> undef, $llvmT %3, i32 0
+        %bfactor = shufflevector <$N x $llvmT> %el2, <$N x $llvmT> undef, <$N x i32> zeroinitializer
+        %tmp = fmul nsz contract <$N x $llvmT> %1, %bfactor
+        %res = call nsz contract <$N x $llvmT> $fmuladd(<$N x $llvmT> %0, <$N x $llvmT> %afactor, <$N x $llvmT> %tmp)
+        ret <$N x $llvmT> %res
+    }
+    """
+    quote
+        $(Expr(:meta, :inline))
+        ret = Base.llvmcall(($llvmir, "entry"), $VT, Tuple{$VT, $VT, $T, $T}, $VT(a), $VT(b), afactor, bfactor)
+        Base.@ntuple $N i->ret[i].value
+    end
 end
 
 ###################
diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl
index 39fb05d7..84320446 100644
--- a/test/PartialsTest.jl
+++ b/test/PartialsTest.jl
@@ -7,6 +7,12 @@ using ForwardDiff: Partials
 
 samerng() = MersenneTwister(1)
 
+# Elementwise `≈` on two tuples; tuples of different lengths never compare as equal.
+function approx_tuple(x, y)
+    length(x) == length(y) || return false
+    return all(((a, b),) -> a ≈ b, zip(x, y))
+end
+
 for N in (0, 3), T in (Int, Float32, Float64)
     println("  ...testing Partials{$N,$T}")
 
@@ -114,7 +120,9 @@ for N in (0, 3), T in (Int, Float32, Float64)
 
     if N > 0
         @test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
-        @test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
+        # The SIMD path of `mul_tuples` uses `llvm.fmuladd`, which may fuse the multiply
+        # and add, so compare against the unfused reference with `≈` instead of `==`.
+        @test approx_tuple(ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values, map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2))
         @test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
         @test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS
 