diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 1528f371..99218c0b 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -114,6 +114,7 @@ include("svd.jl") include("lu.jl") include("qr.jl") include("deque.jl") +include("flatten.jl") include("io.jl") include("FixedSizeArrays.jl") diff --git a/src/flatten.jl b/src/flatten.jl new file mode 100644 index 00000000..727e4c1c --- /dev/null +++ b/src/flatten.jl @@ -0,0 +1,6 @@ +# Special case flatten of iterators of static arrays. +import Base.Iterators: flatten_iteratorsize, flatten_length +flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:StaticArray{S}}) where {S} = Base.HasLength() +function flatten_length(f, T::Type{<:StaticArray{S}}) where {S} + length(T)*length(f.it) +end diff --git a/test/flatten.jl b/test/flatten.jl new file mode 100644 index 00000000..9c8796c8 --- /dev/null +++ b/test/flatten.jl @@ -0,0 +1,15 @@ +using StaticArrays, Test + +@testset "Iterators.flatten" begin + for x in [SVector(1.0, 2.0), MVector(1.0, 2.0), + @SMatrix([1.0 2.0; 3.0 4.0]), @MMatrix([1.0 2.0]), + Size(1,2)([1.0 2.0]) + ] + X = [x,x,x] + @test length(Iterators.flatten(X)) == length(X)*length(x) + @test collect(Iterators.flatten(typeof(x)[])) == [] + @test collect(Iterators.flatten(X)) == [x..., x..., x...] + end + @test collect(Iterators.flatten([SVector(1,1), SVector(1)])) == [1,1,1] + @test_throws ArgumentError length(Iterators.flatten([SVector(1,1), SVector(1)])) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d04b1ac..fd64c65b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,6 +40,7 @@ include("lu.jl") Random.seed!(42); include("qr.jl") Random.seed!(42); include("chol.jl") # hermitian_type(::Type{Any}) for block algorithm include("deque.jl") +include("flatten.jl") include("io.jl") include("svd.jl") Random.seed!(42); include("fixed_size_arrays.jl")