diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 0595a3b8..b1c95636 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -285,3 +285,53 @@ end @inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...))) end end + +struct _InitialValue end + +_maybe_val(dims::Integer) = Val(Int(dims)) +_maybe_val(dims) = dims +_valof(::Val{D}) where D = D + +@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} = + _accumulate(op, a, _maybe_val(dims), init) + +@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} = + _accumulate(op, a, _maybe_val(dims), init) + +@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F} + # Adjoin the initial value to `op`: + rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y) + + if isempty(a) + T = return_type(rf, Tuple{typeof(init), eltype(a)}) + return similar_type(a, T)() + end + + # StaticArrays' `reduce` is `foldl`: + results = _reduce( + a, + dims, + (init = (similar_type(a, Union{}, Size(0))(), init),), + ) do (ys, acc), x + y = rf(acc, x) + # Not using `push(ys, y)` here since we need to widen element type as + # we iterate. + (vcat(ys, SA[y]), y) + end + dims === (:) && return first(results) + + ys = map(first, results) + # Now map over all indices of `a`. Since `_map` needs at least + # one `StaticArray` to be passed, we pass `a` here, even though + # the values of `a` are not used. + data = _map(a, CartesianIndices(a)) do _, CI + D = _valof(dims) + I = Tuple(CI) + J = setindex(I, 1, D) + ys[J...][I[D]] + end + return similar_type(a, eltype(data))(data) +end + +@inline Base.cumsum(a::StaticArray; kw...) = accumulate(Base.add_sum, a; kw...) +@inline Base.cumprod(a::StaticArray; kw...) = accumulate(Base.mul_prod, a; kw...) diff --git a/test/accumulate.jl b/test/accumulate.jl new file mode 100644 index 00000000..97a23e7d --- /dev/null +++ b/test/accumulate.jl @@ -0,0 +1,66 @@ +using StaticArrays, Test + +@testset "accumulate" begin + @testset "cumsum(::$label)" for (label, T) in [ + # label, T + ("SVector", SVector), + ("MVector", MVector), + ("SizedVector", SizedVector), + ] + @testset "$label" for (label, a) in [ + ("[1, 2, 3]", T{3}(SA[1, 2, 3])), + ("[]", T{0,Int}(())), + ] + @test cumsum(a) == cumsum(collect(a)) + @test cumsum(a) isa similar_type(a) + @inferred cumsum(a) + end + @test eltype(cumsum(T{0,Int8}(()))) == eltype(cumsum(Int8[])) + @test eltype(cumsum(T{1,Int8}((1)))) == eltype(cumsum(Int8[1])) + @test eltype(cumsum(T{2,Int8}((1, 2)))) == eltype(cumsum(Int8[1, 2])) + end + + @testset "cumsum(::$label; dims=2)" for (label, T) in [ + # label, T + ("SMatrix", SMatrix), + ("MMatrix", MMatrix), + ("SizedMatrix", SizedMatrix), + ] + @testset "$label" for (label, a) in [ + ("[1 2; 3 4; 5 6]", T{3,2}(SA[1 2; 3 4; 5 6])), + ("0 x 2 matrix", T{0,2,Float64}()), + ("2 x 0 matrix", T{2,0,Float64}()), + ] + @test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2) + @test cumsum(a; dims = 2) isa similar_type(a) + v"1.1" <= VERSION < v"1.2" && continue + @inferred cumsum(a; dims = Val(2)) + end + end + + @testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d + shape = Tuple(1:d) + a = similar_type(SArray, Int, Size(shape))(1:prod(shape)) + @test cumsum(a; dims = i) == cumsum(collect(a); dims = i) + @test cumsum(a; dims = i) isa SArray + v"1.1" <= VERSION < v"1.2" && continue + @inferred cumsum(a; dims = Val(i)) + end + + @testset "cumprod" begin + a = SA[1, 2, 3] + @test cumprod(a)::SArray == cumprod(collect(a)) + @inferred cumprod(a) + + @test eltype(cumsum(SA{Int8}[])) == eltype(cumsum(Int8[])) + @test eltype(cumsum(SA{Int8}[1])) == eltype(cumsum(Int8[1])) + @test eltype(cumsum(SA{Int8}[1, 2])) == eltype(cumsum(Int8[1, 2])) + end + + @testset "empty vector with init" begin + a = SA{Int}[] + right(_, x) = x + @test accumulate(right, a; init = Val(1)) === SA{Int}[] + @inferred accumulate(right, a; init = Val(1)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index eee41e04..34b508fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ include("abstractarray.jl") include("indexing.jl") include("initializers.jl") Random.seed!(42); include("mapreduce.jl") +Random.seed!(42); include("accumulate.jl") Random.seed!(42); include("arraymath.jl") include("broadcast.jl") include("linalg.jl")