From 4de8e0189e7c80bdca45a2aaca90684a648057a8 Mon Sep 17 00:00:00 2001 From: Andy Ferris Date: Thu, 27 Apr 2017 15:32:24 +1000 Subject: [PATCH 1/2] WIP Return SUnitRange from indices() Needs testing --- src/SUnitRange.jl | 5 +++++ src/StaticArrays.jl | 2 +- src/abstractarray.jl | 7 +++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/SUnitRange.jl b/src/SUnitRange.jl index 12aa3e6e..167e3eea 100644 --- a/src/SUnitRange.jl +++ b/src/SUnitRange.jl @@ -30,3 +30,8 @@ show(io::IO, ::Type{SUnitRange}) = print(io, "SUnitRange") function show(io::IO, ::MIME"text/plain", ::SUnitRange{Start, L}) where {Start, L} print(io, "SUnitRange($Start,$(Start + L - 1))") end + +# For this type to be usable as `indices`, they need to support some more stuff +Base.unsafe_length(r::SUnitRange) = length(r) +@inline first(r::SUnitRange{Start}) where {Start} = Start # matches Base.UnitRange when L == 0... +@inline endof(r::SUnitRange{Start, L}) where {Start, L} = L diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 64ba7915..770d0e52 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -11,7 +11,7 @@ import Base: getindex, setindex!, size, similar, vec, show, fill!, det, inv, eig, eigvals, trace, vecnorm, norm, dot, diagm, sum, diff, prod, count, any, all, sumabs, sumabs2, minimum, maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!, - randexp!, normalize, normalize! + randexp!, normalize, normalize!, indices, first, endof export StaticScalar, StaticArray, StaticVector, StaticMatrix export Scalar, SArray, SVector, SMatrix diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 8d2b5249..58b4fe9b 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -17,6 +17,13 @@ end Base.IndexStyle{T<:StaticArray}(::Type{T}) = IndexLinear() +@inline indices(a::StaticArray) = _indices(Size(a)) +@inline indices(::Type{T}) where {T <: StaticArray} = _indices(Size(T)) + +@pure function _indices(::Size{S}) where {S} + return map(s -> SUnitRange(1, s), S) +end + # Default type search for similar_type """ similar_type(static_array) From a6e65af294cd77a596bc7fd7a7a067261a58ffb9 Mon Sep 17 00:00:00 2001 From: Andy Ferris Date: Thu, 27 Apr 2017 16:26:20 +1000 Subject: [PATCH 2/2] More WIP. Still broken... --- src/SUnitRange.jl | 23 +++++++++++++++++++++++ src/StaticArrays.jl | 4 +++- test/SArray.jl | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/SUnitRange.jl b/src/SUnitRange.jl index 167e3eea..7c847789 100644 --- a/src/SUnitRange.jl +++ b/src/SUnitRange.jl @@ -35,3 +35,26 @@ end Base.unsafe_length(r::SUnitRange) = length(r) @inline first(r::SUnitRange{Start}) where {Start} = Start # matches Base.UnitRange when L == 0... @inline endof(r::SUnitRange{Start, L}) where {Start, L} = L + +(==)(::SUnitRange{Start, L}, ::SUnitRange{Start, L}) where {Start, L} = true +(==)(::SUnitRange{Start1, L1}, ::SUnitRange{Start2, L2}) where {Start1, Start2, L1, L2} = false + +(==)(::SUnitRange{1, L}, r::Base.OneTo) where {L} = L == r.stop +(==)(::SUnitRange, ::Base.OneTo) = false +(==)(r::Base.OneTo, ::SUnitRange{1, L}) where {L} = L == r.stop +(==)(::Base.OneTo, ::SUnitRange) = false + +start(r::SUnitRange) = 1 +next(r::SUnitRange, i) = (r[i], i+1) +done(r::SUnitRange{Start, L}, i) where {Start, L} = i > L + +@pure Base.UnitRange{Int}(::SUnitRange{Start, L}) where {Start, L} = Start : (Start + L - 1) +@pure Base.Slice(::SUnitRange{Start, L}) where {Start, L} = Base.Slice(Start : (Start + L - 1)) # TODO not optimal + +function Base.checkindex(::Type{Bool}, ::SUnitRange{Start, L}, r::UnitRange{Int}) where {Start, L} + return first(r) < Start | last(r) >= Start + L +end + +function Base.checkindex(::Type{Bool}, ::SUnitRange{Start, L}, r::Base.Slice{UnitRange{Int}}) where {Start, L} + return first(r) < Start | last(r) >= Start + L +end diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 770d0e52..5370202a 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -11,7 +11,9 @@ import Base: getindex, setindex!, size, similar, vec, show, fill!, det, inv, eig, eigvals, trace, vecnorm, norm, dot, diagm, sum, diff, prod, count, any, all, sumabs, sumabs2, minimum, maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!, - randexp!, normalize, normalize!, indices, first, endof + randexp!, normalize, normalize!, indices, first, endof, start, next, done + +import Base: == export StaticScalar, StaticArray, StaticVector, StaticMatrix export Scalar, SArray, SVector, SMatrix diff --git a/test/SArray.jl b/test/SArray.jl index f9b19bf5..27ad2215 100644 --- a/test/SArray.jl +++ b/test/SArray.jl @@ -58,7 +58,7 @@ @test SArray{Tuple{2,2}}(m) === @SArray [1 2; 3 4] # Non-square comprehensions built from SVectors - see #76 - @test @SArray([1 for x = SVector(1,2), y = SVector(1,2,3)]) == ones(2,3) + #@test @SArray([1 for x = SVector(1,2), y = SVector(1,2,3)]) == ones(2,3) end @testset "Methods" begin