|
"""
    Partials{N,V} <: AbstractVector{V}

Immutable, fixed-length vector of `N` elements of type `V`, stored as an
`NTuple{N,V}` in the `values` field.
"""
struct Partials{N,V} <: AbstractVector{V}
    values::NTuple{N,V}
end
| 4 | + |
| 5 | +############################## |
| 6 | +# Utility/Accessor Functions # |
| 7 | +############################## |
| 8 | + |
# Build a `Partials` whose tuple holds `one(V)` in slot `i` and `zero(V)` in every
# other slot of `1:N`. The tuple expression is assembled at generation time.
@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
    entries = map(1:N) do slot
        i === slot ? :(one(V)) : :(zero(V))
    end
    seed = Expr(:tuple, entries...)
    return :(Partials($seed))
end
| 13 | + |
# `valtype` returns the element type parameter `V`, for an instance or for the type.
@inline function valtype(::Partials{N,V}) where {N,V}
    return V
end
@inline function valtype(::Type{Partials{N,V}}) where {N,V}
    return V
end

# `npartials` returns the length parameter `N`, for an instance or for the type.
@inline function npartials(::Partials{N}) where {N}
    return N
end
@inline function npartials(::Type{Partials{N,V}}) where {N,V}
    return N
end
| 19 | + |
# AbstractVector interface: the length is the type parameter `N`.
@inline function Base.length(::Partials{N}) where {N}
    return N
end
@inline function Base.size(::Partials{N}) where {N}
    return (N,)
end

# Indexing forwards to the wrapped tuple; the caller's `@inbounds` is propagated.
@inline Base.@propagate_inbounds function Base.getindex(partials::Partials, i::Int)
    return partials.values[i]
end

# Iteration is delegated to the wrapped tuple, state included.
function Base.iterate(partials::Partials)
    return iterate(partials.values)
end
function Base.iterate(partials::Partials, i)
    return iterate(partials.values, i)
end

function Base.IndexStyle(::Type{<:Partials})
    return IndexLinear()
end
| 29 | + |
# Always report that a `Partials` does not alias another array.
# Can be deleted after https://github.com/JuliaLang/julia/pull/29854 is on a release
function Base.mightalias(x::AbstractArray, y::Partials)
    return false
end
| 32 | + |
| 33 | +##################### |
| 34 | +# Generic Functions # |
| 35 | +##################### |
| 36 | + |
# True when every stored value equals zero (see `iszero_tuple`).
@inline function Base.iszero(partials::Partials)
    return iszero_tuple(partials.values)
end

# `zero` of an instance defers to `zero` of its type, which fills the tuple
# via `zero_tuple`.
@inline function Base.zero(partials::Partials)
    return zero(typeof(partials))
end
@inline function Base.zero(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(zero_tuple(NTuple{N,V}))
end

# `one` mirrors `zero`, filling the tuple via `one_tuple`.
@inline function Base.one(partials::Partials)
    return one(typeof(partials))
end
@inline function Base.one(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(one_tuple(NTuple{N,V}))
end
| 44 | + |
# Random `Partials`: instance methods defer to the type methods, which fill the
# tuple through `rand_tuple` (with or without an explicit RNG).
@inline function Random.rand(partials::Partials)
    return rand(typeof(partials))
end
@inline function Random.rand(::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(rand_tuple(NTuple{N,V}))
end
@inline function Random.rand(rng::AbstractRNG, partials::Partials)
    return rand(rng, typeof(partials))
end
@inline function Random.rand(rng::AbstractRNG, ::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(rand_tuple(rng, NTuple{N,V}))
end
| 49 | + |
# Equality between two `Partials` of the same length compares the wrapped tuples.
function Base.isequal(a::Partials{N}, b::Partials{N}) where {N}
    return isequal(a.values, b.values)
end
function Base.:(==)(a::Partials{N}, b::Partials{N}) where {N}
    return a.values == b.values
end
| 52 | + |
# Seed mixed into every `Partials` hash.
const PARTIALS_HASH = hash(Partials)

# One-argument hash: hash the wrapped tuple, seeded with `PARTIALS_HASH`.
Base.hash(partials::Partials) = hash(partials.values, PARTIALS_HASH)

# Two-argument hash: mix the one-argument hash into the caller's seed.
# The seed is typed `UInt` (the platform word size used by `Base.hash`) rather than
# `UInt64`, so this method is also selected on 32-bit platforms.
Base.hash(partials::Partials, hsh::UInt) = hash(hash(partials), hsh)
| 57 | + |
# `copy` returns the argument itself; no new object is created.
@inline function Base.copy(partials::Partials)
    return partials
end
| 59 | + |
# Read `N` consecutive values of type `V` from `io` and wrap them in a `Partials{N,V}`.
# `N` is a static parameter, so the `Val(N)` form of `ntuple` is used to build the
# tuple with a length known at compile time.
function Base.read(io::IO, ::Type{Partials{N,V}}) where {N,V}
    return Partials{N,V}(ntuple(_ -> read(io, V), Val(N)))
end
| 61 | + |
# Write every element of `partials` to `io`, in order.
# Returns the total of the values returned by the per-element `write` calls, i.e.
# the number of bytes written, as `Base.write` methods are expected to do
# (previously this returned `nothing`).
function Base.write(io::IO, partials::Partials)
    nbytes = 0
    for p in partials
        nbytes += write(io, p)
    end
    return nbytes
end
| 67 | + |
| 68 | +######################## |
| 69 | +# Conversion/Promotion # |
| 70 | +######################## |
| 71 | + |
# Two `Partials` of equal length promote to one with the promoted element type.
function Base.promote_rule(::Type{Partials{N,A}}, ::Type{Partials{N,B}}) where {N,A,B}
    return Partials{N,promote_type(A, B)}
end

# Conversion rebuilds from the wrapped tuple; an already-matching value is returned as is.
function Base.convert(::Type{Partials{N,V}}, partials::Partials) where {N,V}
    return Partials{N,V}(partials.values)
end
function Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V}
    return partials
end
| 76 | + |
| 77 | +######################## |
| 78 | +# Arithmetic Functions # |
| 79 | +######################## |
| 80 | + |
# Element-wise sum and difference of two equal-length `Partials`.
@inline function Base.:+(a::Partials{N}, b::Partials{N}) where {N}
    return Partials(add_tuples(a.values, b.values))
end
@inline function Base.:-(a::Partials{N}, b::Partials{N}) where {N}
    return Partials(sub_tuples(a.values, b.values))
end

# Element-wise negation.
@inline function Base.:-(partials::Partials)
    return Partials(minus_tuple(partials.values))
end

# `Real * Partials` is forwarded to the `Partials * Real` method.
@inline function Base.:*(x::Real, partials::Partials)
    return partials*x
end
| 85 | + |
# Combine `a` and `b` through `_mul_partials` with the factors
# `inv(bval)` for `a` and `-(aval / bval^2)` for `b`.
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
    a_weight = inv(bval)
    b_weight = -(aval / (bval*bval))
    return _mul_partials(a, b, a_weight, b_weight)
end
| 89 | + |
| 90 | +# NaN/Inf-safe methods # |
| 91 | +#----------------------# |
| 92 | + |
if NANSAFE_MODE_ENABLED
    # NaN-safe variants: when the scalar is problematic (non-finite for `*` and
    # `_mul_partials`, zero for `/`) and the partials it applies to are all zero,
    # the scalar is replaced by `one(...)` before the operation.
    @inline function Base.:*(partials::Partials, x::Real)
        factor = ifelse(!isfinite(x) && iszero(partials), one(x), x)
        return Partials(scale_tuple(partials.values, factor))
    end

    @inline function Base.:/(partials::Partials, x::Real)
        divisor = ifelse(x == zero(x) && iszero(partials), one(x), x)
        return Partials(div_tuple_by_scalar(partials.values, divisor))
    end

    @inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
        factor_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
        factor_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
        return Partials(mul_tuples(a.values, b.values, factor_a, factor_b))
    end
else
    # Plain variants: the scalar is applied as given.
    @inline Base.:*(partials::Partials, x::Real) =
        Partials(scale_tuple(partials.values, x))

    @inline Base.:/(partials::Partials, x::Real) =
        Partials(div_tuple_by_scalar(partials.values, x))

    @inline _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N =
        Partials(mul_tuples(a.values, b.values, x_a, x_b))
end
| 122 | + |
| 123 | +# edge cases where N == 0 # |
| 124 | +#-------------------------# |
| 125 | + |
# Addition with an empty operand: the result is the non-empty operand converted to
# the promoted element type (or an empty `Partials` when both are empty).
@inline function Base.:+(a::Partials{0,A}, b::Partials{0,B}) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function Base.:+(a::Partials{0,A}, b::Partials{N,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, b)
end
@inline function Base.:+(a::Partials{N,A}, b::Partials{0,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, a)
end

# Subtraction with an empty operand: an empty left side negates the right side.
@inline function Base.:-(a::Partials{0,A}, b::Partials{0,B}) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function Base.:-(a::Partials{0,A}, b::Partials{N,B}) where {N,A,B}
    return -(convert(Partials{N,promote_type(A,B)}, b))
end
@inline function Base.:-(a::Partials{N,A}, b::Partials{0,B}) where {N,A,B}
    return convert(Partials{N,promote_type(A,B)}, a)
end

# Negating an empty `Partials` returns it unchanged.
@inline function Base.:-(partials::Partials{0,V}) where {V}
    return partials
end
| 134 | + |
# Scaling or dividing an empty `Partials` yields an empty `Partials` whose element
# type is promoted with the scalar's type.
@inline function Base.:*(partials::Partials{0,V}, x::Real) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end
@inline function Base.:*(x::Real, partials::Partials{0,V}) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end

@inline function Base.:/(partials::Partials{0,V}, x::Real) where {V}
    return Partials{0,promote_type(V,typeof(x))}(())
end

# `_mul_partials` with an empty operand: only the non-empty side contributes,
# scaled by its own factor.
@inline function _mul_partials(a::Partials{0,A}, b::Partials{0,B}, afactor, bfactor) where {A,B}
    return Partials{0,promote_type(A,B)}(())
end
@inline function _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B}
    return bfactor * b
end
@inline function _mul_partials(a::Partials{N,A}, b::Partials{0,B}, afactor, bfactor) where {N,A,B}
    return afactor * a
end
| 143 | + |
| 144 | +################################## |
| 145 | +# Generated Functions on NTuples # |
| 146 | +################################## |
| 147 | +# The below functions are generally |
| 148 | +# equivalent to directly mapping over |
| 149 | +# tuples using `map`, but run a bit |
| 150 | +# faster since they generate inline code |
| 151 | +# that doesn't rely on closures. |
| 152 | + |
# Build the body of a generated function: an inline-annotated block that returns
# (under `@inbounds`) the tuple expression `(f(1), f(2), ..., f(N))`.
# `f` maps an index to the expression for that tuple slot.
function tupexpr(f, N)
    slots = map(f, 1:N)
    tuple_ex = Expr(:tuple, slots...)
    return quote
        $(Expr(:meta, :inline))
        @inbounds return $tuple_ex
    end
end
| 160 | + |
# Empty-tuple methods for the tuple helpers: an empty tuple counts as zero, and
# every constructor returns the empty tuple.
@inline function iszero_tuple(::Tuple{})
    return true
end
@inline function zero_tuple(::Type{Tuple{}})
    return ()
end
@inline function one_tuple(::Type{Tuple{}})
    return ()
end
@inline function rand_tuple(::AbstractRNG, ::Type{Tuple{}})
    return ()
end
@inline function rand_tuple(::Type{Tuple{}})
    return ()
end
| 166 | + |
# True when every element of `tup` compares `==` to `zero(V)`.
# The comparisons are unrolled into a single `&&` chain at generation time.
@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V}
    comparisons = map(k -> :(z == tup[$k]), 1:N)
    chain = Expr(:&&, comparisons...)
    return quote
        z = zero(V)
        $(Expr(:meta, :inline))
        @inbounds return $chain
    end
end
| 175 | + |
# Tuple of `N` copies of `zero(V)`; `zero(V)` is evaluated once and reused.
@generated function zero_tuple(::Type{NTuple{N,V}}) where {N,V}
    filled = tupexpr(_ -> :(z), N)
    return quote
        z = zero(V)
        return $filled
    end
end
| 183 | + |
# Tuple of `N` copies of `one(V)`; `one(V)` is evaluated once and reused.
@generated function one_tuple(::Type{NTuple{N,V}}) where {N,V}
    filled = tupexpr(_ -> :(z), N)
    return quote
        z = one(V)
        return $filled
    end
end
| 191 | + |
# Tuple of `N` independent `rand(rng, V)` draws.
@generated function rand_tuple(rng::AbstractRNG, ::Type{NTuple{N,V}}) where {N,V}
    draw = _ -> :(rand(rng, V))
    return tupexpr(draw, N)
end

# Tuple of `N` independent `rand(V)` draws.
@generated function rand_tuple(::Type{NTuple{N,V}}) where {N,V}
    draw = _ -> :(rand(V))
    return tupexpr(draw, N)
end
| 199 | + |
# Element types that have a mapping in `julia_type_to_llvm_type`.
const SIMDFloat = Union{Float64, Float32}
const SIMDInt = Union{
    Int128, Int64, Int32, Int16, Int8,
    UInt128, UInt64, UInt32, UInt16, UInt8,
    Bool,
}
const SIMDType = Union{SIMDFloat, SIMDInt}

# This may not be a sharp bound, but at least nobody will get worse results.
# (The spelling of the name is kept as is because other code refers to it.)
const HAS_FLEXIABLE_VECTOR_LENGTH = VERSION >= v"1.6"
| 210 | + |
# Return the LLVM scalar type name for a Julia primitive type.
# Signed and unsigned integers of the same width share one name; `Bool` maps to "i8".
# Any other type raises an error.
function julia_type_to_llvm_type(@nospecialize(T::DataType))
    if T === Float64
        return "double"
    elseif T === Float32
        return "float"
    elseif T <: Union{Int128,UInt128}
        return "i128"
    elseif T <: Union{Int64,UInt64}
        return "i64"
    elseif T <: Union{Int32,UInt32}
        return "i32"
    elseif T <: Union{Int16,UInt16}
        return "i16"
    elseif T <: Union{Bool,Int8,UInt8}
        return "i8"
    end
    return error("$T cannot be mapped to a LLVM type")
end
| 221 | + |
# Multiply every element of `tup` by the scalar `x`.
# When the element type equals the scalar type and is a `SIMDType` (and vector
# lengths are flexible), the product is emitted as one LLVM vector instruction;
# otherwise the element-wise products are unrolled.
@generated function scale_tuple(tup::NTuple{N,T}, x::S) where {N,T,S}
    vectorize = HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDType
    if !vectorize
        return tupexpr(i -> :(tup[$i] * x), N)
    end

    elty = julia_type_to_llvm_type(T)
    VT = NTuple{N, VecElement{T}}
    op = T <: SIMDFloat ? "fmul nsz contract" : "mul"
    # Splat `x` across a vector, then apply `op` lane-wise.
    llvmir = """
    %el = insertelement <$N x $elty> undef, $elty %1, i32 0
    %vx = shufflevector <$N x $elty> %el, <$N x $elty> undef, <$N x i32> zeroinitializer
    %res = $op <$N x $elty> %0, %vx
    ret <$N x $elty> %res
    """

    return quote
        $(Expr(:meta, :inline))
        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
        Base.@ntuple $N i->ret[i].value
    end
end
| 243 | + |
# Divide every element of `tup` by the scalar `x`.
# When the element type equals the scalar type and is `Float64` or `Float32` (and
# vector lengths are flexible), the division is emitted as one LLVM `fdiv` vector
# instruction; otherwise the element-wise quotients are unrolled.
@generated function div_tuple_by_scalar(tup::NTuple{N,T}, x::S) where {N,T,S}
    # Of the `SIMDType` members only the floats are closed under `/`, so the vector
    # path is restricted to `SIMDFloat` directly. This admits exactly the same types
    # as before, without evaluating `one(T) / one(S)` in the generator and without an
    # unreachable integer opcode branch.
    if !(HAS_FLEXIABLE_VECTOR_LENGTH && T === S && S <: SIMDFloat)
        return tupexpr(i -> :(tup[$i] / x), N)
    end

    elty = julia_type_to_llvm_type(T)
    VT = NTuple{N, VecElement{T}}
    # Splat `x` across a vector, then divide lane-wise.
    llvmir = """
    %el = insertelement <$N x $elty> undef, $elty %1, i32 0
    %vx = shufflevector <$N x $elty> %el, <$N x $elty> undef, <$N x i32> zeroinitializer
    %res = fdiv nsz contract <$N x $elty> %0, %vx
    ret <$N x $elty> %res
    """

    return quote
        $(Expr(:meta, :inline))
        ret = Base.llvmcall($llvmir, $VT, Tuple{$VT, $T}, $VT(tup), x)
        Base.@ntuple $N i->ret[i].value
    end
end
| 265 | + |
# Element-wise sum of two tuples of equal length.
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
    slot = k -> :(a[$k] + b[$k])
    return tupexpr(slot, N)
end

# Element-wise negation of a tuple.
@generated function minus_tuple(tup::NTuple{N}) where N
    slot = k -> :(-tup[$k])
    return tupexpr(slot, N)
end

# Element-wise difference of two tuples of equal length.
@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
    slot = k -> :(a[$k] - b[$k])
    return tupexpr(slot, N)
end
| 277 | + |
| 278 | + |
# Element-wise `afactor * a[i] + bfactor * b[i]`.
# When both tuples and both factors share one `SIMDFloat` type (and vector lengths
# are flexible), the computation is emitted as LLVM vector code built around
# `llvm.fmuladd`; otherwise the element-wise expressions are unrolled.
@generated function mul_tuples(a::NTuple{N,V1}, b::NTuple{N,V2}, afactor::S1, bfactor::S2) where {N,V1,V2,S1,S2}
    vectorize = HAS_FLEXIABLE_VECTOR_LENGTH && V1 === V2 === S1 === S2 && S2 <: SIMDFloat
    if !vectorize
        return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
    end

    T = V1
    elty = julia_type_to_llvm_type(T)
    # Intrinsic name encodes the vector length and the float width in bits.
    fmuladd = "@llvm.fmuladd.v$(N)f$(sizeof(T)*8)"
    VT = NTuple{N, VecElement{T}}

    # Splat both factors, compute `b * bfactor`, then fuse `a * afactor + tmp`.
    llvmir = """
    declare <$N x $elty> $fmuladd(<$N x $elty>, <$N x $elty>, <$N x $elty>)

    define <$N x $elty> @entry(<$N x $elty>, <$N x $elty>, $elty, $elty) alwaysinline {
    top:
    %el1 = insertelement <$N x $elty> undef, $elty %2, i32 0
    %afactor = shufflevector <$N x $elty> %el1, <$N x $elty> undef, <$N x i32> zeroinitializer
    %el2 = insertelement <$N x $elty> undef, $elty %3, i32 0
    %bfactor = shufflevector <$N x $elty> %el2, <$N x $elty> undef, <$N x i32> zeroinitializer
    %tmp = fmul nsz contract <$N x $elty> %1, %bfactor
    %res = call nsz contract <$N x $elty> $fmuladd(<$N x $elty> %0, <$N x $elty> %afactor, <$N x $elty> %tmp)
    ret <$N x $elty> %res
    }
    """

    return quote
        $(Expr(:meta, :inline))
        ret = Base.llvmcall(($llvmir, "entry"), $VT, Tuple{$VT, $VT, $T, $T}, $VT(a), $VT(b), afactor, bfactor)
        Base.@ntuple $N i->ret[i].value
    end
end
| 309 | + |
| 310 | +################### |
| 311 | +# Pretty Printing # |
| 312 | +################### |
| 313 | + |
# Print as the literal text "Partials" followed by the wrapped tuple.
function Base.show(io::IO, p::Partials{N}) where {N}
    return print(io, "Partials", p.values)
end
0 commit comments