From 3d55528042a103e51985831e1a8eb75e419018d1 Mon Sep 17 00:00:00 2001 From: mschauer Date: Thu, 27 Sep 2018 15:42:23 +0200 Subject: [PATCH 1/3] Special case flatten of iterators of static arrays --- src/StaticArrays.jl | 1 + src/flatten.jl | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 src/flatten.jl diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 1528f371..99218c0b 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -114,6 +114,7 @@ include("svd.jl") include("lu.jl") include("qr.jl") include("deque.jl") +include("flatten.jl") include("io.jl") include("FixedSizeArrays.jl") diff --git a/src/flatten.jl b/src/flatten.jl new file mode 100644 index 00000000..90d121aa --- /dev/null +++ b/src/flatten.jl @@ -0,0 +1,6 @@ +# Special case flatten of iterators of static arrays. +import Base.Iterators: flatten_iteratorsize, flatten_length +flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:Union{SArray,MArray}}) = Base.HasLength() +function flatten_length(f, T::Type{<:Union{SArray,MArray}}) + length(T)*length(f.it) +end From e07ad4fbd3da4054c348af45ed740354d3b782dc Mon Sep 17 00:00:00 2001 From: mschauer Date: Thu, 27 Sep 2018 16:11:53 +0200 Subject: [PATCH 2/3] Tests and refinement --- src/flatten.jl | 4 ++-- test/flatten.jl | 13 +++++++++++++ test/runtests.jl | 1 + 3 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 test/flatten.jl diff --git a/src/flatten.jl b/src/flatten.jl index 90d121aa..34222d91 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -1,6 +1,6 @@ # Special case flatten of iterators of static arrays. import Base.Iterators: flatten_iteratorsize, flatten_length -flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:Union{SArray,MArray}}) = Base.HasLength() -function flatten_length(f, T::Type{<:Union{SArray,MArray}}) +flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:StaticArray}) = Base.HasLength() +function flatten_length(f, T::Type{<:StaticArray}) length(T)*length(f.it) end diff --git a/test/flatten.jl b/test/flatten.jl new file mode 100644 index 00000000..c88735ca --- /dev/null +++ b/test/flatten.jl @@ -0,0 +1,13 @@ +using StaticArrays, Test + +@testset "Iterators.flatten" begin + for x in [SVector(1.0, 2.0), MVector(1.0, 2.0), + @SMatrix([1.0 2.0; 3.0 4.0]), @MMatrix([1.0 2.0]), + Size(1,2)([1.0 2.0]) + ] + X = [x,x,x] + @test length(Iterators.flatten(X)) == length(X)*length(x) + @test collect(Iterators.flatten(typeof(x)[])) == [] + @test collect(Iterators.flatten(X)) == [x..., x..., x...] + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d04b1ac..fd64c65b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,6 +40,7 @@ include("lu.jl") Random.seed!(42); include("qr.jl") Random.seed!(42); include("chol.jl") # hermitian_type(::Type{Any}) for block algorithm include("deque.jl") +include("flatten.jl") include("io.jl") include("svd.jl") Random.seed!(42); include("fixed_size_arrays.jl") From 76eea3c6dcaa34207aebe0f2dadf228a044808b6 Mon Sep 17 00:00:00 2001 From: mschauer Date: Thu, 27 Sep 2018 16:41:06 +0200 Subject: [PATCH 3/3] Handle mixed size iterators --- src/flatten.jl | 4 ++-- test/flatten.jl | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/flatten.jl b/src/flatten.jl index 34222d91..727e4c1c 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -1,6 +1,6 @@ # Special case flatten of iterators of static arrays. import Base.Iterators: flatten_iteratorsize, flatten_length -flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:StaticArray}) = Base.HasLength() -function flatten_length(f, T::Type{<:StaticArray}) +flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:StaticArray{S}}) where {S} = Base.HasLength() +function flatten_length(f, T::Type{<:StaticArray{S}}) where {S} length(T)*length(f.it) end diff --git a/test/flatten.jl b/test/flatten.jl index c88735ca..9c8796c8 100644 --- a/test/flatten.jl +++ b/test/flatten.jl @@ -10,4 +10,6 @@ using StaticArrays, Test @test collect(Iterators.flatten(typeof(x)[])) == [] @test collect(Iterators.flatten(X)) == [x..., x..., x...] end + @test collect(Iterators.flatten([SVector(1,1), SVector(1)])) == [1,1,1] + @test_throws ArgumentError length(Iterators.flatten([SVector(1,1), SVector(1)])) end