# Patch summary: unified diff touching src/ForwardDiff.jl, src/dual.jl and src/partials.jl.
#
# NOTE(review): the hunk headers below declare multi-line ranges (e.g. "-16,6 +16,13"),
# but the hunk content sits on two physical lines here, so the original line breaks
# appear to have been lost. Restore them before trying to apply this patch.
#
# src/ForwardDiff.jl
#   Adds the constants SIMDFloat (Union of Float64 and Float32), SIMDInt (Union of the
#   signed and unsigned 8- to 128-bit integer types) and SIMDType (Union of both)
#   ahead of the include("prelude.jl") line.
#
# src/partials.jl
#   Removes the same three constant definitions; after this patch they are defined in
#   ForwardDiff.jl before the include lines instead.
#
# src/dual.jl
#   Adds an @inline calc_fma_xyz method for three Dual{T,V,N} arguments with
#   V<:SIMDFloat, inserted ahead of the existing @generated method. It computes the
#   value as fma(xv, yv, zv), returns Dual{T}(rv) early when N == 0, and otherwise wraps
#   each argument's partials tuple in Vec and builds the result partials as
#   fma(xv, yp, fma(yv, xp, zp)) (the product-rule expression xv*yp + yv*xp + zp),
#   converting back with Tuple before constructing Dual{T}(rv, parts).
#   Adds the analogous calc_muladd_xyz method built on muladd; it is restricted to
#   V<:SIMDType, so the integer element types are accepted as well as the floats.
#
# NOTE(review): Vec is not defined or imported anywhere in the visible hunks; confirm
# it is in scope in dual.jl.
diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index 35ee6a8a..5174e324 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -16,6 +16,13 @@ import SpecialFunctions import LogExpFunctions import CommonSubexpressions +const SIMDFloat = Union{Float64, Float32} +const SIMDInt = Union{ + Int128, Int64, Int32, Int16, Int8, + UInt128, UInt64, UInt32, UInt16, UInt8, + } +const SIMDType = Union{SIMDFloat, SIMDInt} + include("prelude.jl") include("partials.jl") include("dual.jl") diff --git a/src/dual.jl b/src/dual.jl index 03330c00..75d46378 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -541,6 +541,16 @@ end # fma # #-----# +@inline function calc_fma_xyz(x::Dual{T,V,N}, + y::Dual{T,V,N}, + z::Dual{T,V,N}) where {T, V<:SIMDFloat,N} + xv, yv, zv = value(x), value(y), value(z) + rv = fma(xv, yv, zv) + N == 0 && return Dual{T}(rv) + xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values) + parts = Tuple(fma(xv, yp, fma(yv, xp, zp))) + Dual{T}(rv, parts) +end @generated function calc_fma_xyz(x::Dual{T,<:Any,N}, y::Dual{T,<:Any,N}, z::Dual{T,<:Any,N}) where {T,N} @@ -583,6 +593,16 @@ end # muladd # #--------# +@inline function calc_muladd_xyz(x::Dual{T,V,N}, + y::Dual{T,V,N}, + z::Dual{T,V,N}) where {T, V<:SIMDType,N} + xv, yv, zv = value(x), value(y), value(z) + rv = muladd(xv, yv, zv) + N == 0 && return Dual{T}(rv) + xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values) + parts = Tuple(muladd(xv, yp, muladd(yv, xp, zp))) + Dual{T}(rv, parts) +end @generated function calc_muladd_xyz(x::Dual{T,<:Any,N}, y::Dual{T,<:Any,N}, z::Dual{T,<:Any,N}) where {T,N} diff --git a/src/partials.jl b/src/partials.jl index eca0a76c..7a94884e 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -205,13 +205,6 @@ end return tupexpr(i -> :(rand(V)), N) end - -const SIMDFloat = Union{Float64, Float32} -const SIMDInt = Union{ - Int128, Int64, Int32, Int16, Int8, - UInt128, UInt64, UInt32, UInt16, UInt8, - } -const SIMDType 
= Union{SIMDFloat, SIMDInt} const NT{N,T} = NTuple{N,T} # SIMD implementation