Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@ ReshapedArray{T,N}(parent::AbstractArray{T}, dims::NTuple{N,Int}, mi) = Reshaped
typealias ReshapedArrayLF{T,N,P<:AbstractArray} ReshapedArray{T,N,P,Tuple{}}

# Fast iteration on ReshapedArrays: use the parent iterator
immutable ReshapedRange{I,M}
immutable ReshapedArrayIterator{I,M}
iter::I
mi::NTuple{M,SignedMultiplicativeInverse{Int}}
end
ReshapedRange(A::ReshapedArray) = reshapedrange(parent(A), A.mi)
function reshapedrange{M}(P, mi::NTuple{M})
ReshapedArrayIterator(A::ReshapedArray) = _rs_iterator(parent(A), A.mi)
function _rs_iterator{M}(P, mi::NTuple{M})
iter = eachindex(P)
ReshapedRange{typeof(iter),M}(iter, mi)
ReshapedArrayIterator{typeof(iter),M}(iter, mi)
end

immutable ReshapedIndex{T}
parentindex::T
end

# eachindex(A::ReshapedArray) = ReshapedRange(A) # TODO: uncomment this line
start(R::ReshapedRange) = start(R.iter)
@inline done(R::ReshapedRange, i) = done(R.iter, i)
@inline function next(R::ReshapedRange, i)
# eachindex(A::ReshapedArray) = ReshapedArrayIterator(A) # TODO: uncomment this line
start(R::ReshapedArrayIterator) = start(R.iter)
@inline done(R::ReshapedArrayIterator, i) = done(R.iter, i)
@inline function next(R::ReshapedArrayIterator, i)
item, inext = next(R.iter, i)
ReshapedIndex(item), inext
end
length(R::ReshapedRange) = length(R.iter)
length(R::ReshapedArrayIterator) = length(R.iter)

function reshape(parent::AbstractArray, dims::Dims)
prod(dims) == length(parent) || throw(DimensionMismatch("parent has $(length(parent)) elements, which is incompatible with size $dims"))
Expand Down Expand Up @@ -84,19 +84,61 @@ reinterpret{T}(::Type{T}, A::ReshapedArray, dims::Dims) = reinterpret(T, parent(
ind2sub_rs((d+1, out...), tail(strds), r)
end

@inline getindex(A::ReshapedArrayLF, index::Int) = (@boundscheck checkbounds(A, index); @inbounds ret = parent(A)[index]; ret)
@inline getindex(A::ReshapedArray, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_getindex(A, indexes...))
@inline getindex(A::ReshapedArray, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds ret = parent(A)[index.parentindex]; ret)
@inline function getindex(A::ReshapedArrayLF, index::Int)
@boundscheck checkbounds(A, index)
@inbounds ret = parent(A)[index]
ret
end
@inline function getindex(A::ReshapedArray, indexes::Int...)
@boundscheck checkbounds(A, indexes...)
_unsafe_getindex(A, indexes...)
end
@inline function getindex(A::ReshapedArray, index::ReshapedIndex)
@boundscheck checkbounds(parent(A), index.parentindex)
@inbounds ret = parent(A)[index.parentindex]
ret
end

@inline function _unsafe_getindex(A::ReshapedArray, indexes::Int...)
@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]
ret
end
@inline function _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...)
@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]
ret
end

@inline _unsafe_getindex(A::ReshapedArray, indexes::Int...) = (@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]; ret)
@inline _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...) = (@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]; ret)
@inline function setindex!(A::ReshapedArrayLF, val, index::Int)
@boundscheck checkbounds(A, index)
@inbounds parent(A)[index] = val
val
end
@inline function setindex!(A::ReshapedArray, val, indexes::Int...)
@boundscheck checkbounds(A, indexes...)
_unsafe_setindex!(A, val, indexes...)
end
@inline function setindex!(A::ReshapedArray, val, index::ReshapedIndex)
@boundscheck checkbounds(parent(A), index.parentindex)
@inbounds parent(A)[index.parentindex] = val
val
end

@inline function _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...)
@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val
val
end
@inline function _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...)
@inbounds parent(A)[sub2ind(size(A), indexes...)] = val
val
end

@inline setindex!(A::ReshapedArrayLF, val, index::Int) = (@boundscheck checkbounds(A, index); @inbounds parent(A)[index] = val; val)
@inline setindex!(A::ReshapedArray, val, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_setindex!(A, val, indexes...))
@inline setindex!(A::ReshapedArray, val, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds parent(A)[index.parentindex] = val; val)
# helpful error message for a common failure case
typealias ReshapedRange{T,N,A<:Range} ReshapedArray{T,N,A,Tuple{}}
setindex!(A::ReshapedRange, val, index::Int) = _rs_setindex!_err()
setindex!(A::ReshapedRange, val, indexes::Int...) = _rs_setindex!_err()
setindex!(A::ReshapedRange, val, index::ReshapedIndex) = _rs_setindex!_err()

@inline _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...) = (@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val; val)
@inline _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...) = (@inbounds parent(A)[sub2ind(size(A), indexes...)] = val; val)
_rs_setindex!_err() = error("indexed assignment fails for a reshaped range; consider calling collect")

typealias ArrayT{N, T} Array{T,N}
convert{T,S,N}(::Type{Array{T,N}}, V::ReshapedArray{S,N}) = copy!(Array(T, size(V)), V)
Expand Down
41 changes: 39 additions & 2 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,50 @@ a = reshape(b, (2, 2, 2, 2, 2))
@test a[2,2,2,2,2] == b[end]

# reshaping linearslow arrays
a = zeros(1, 5)
a = collect(reshape(1:5, 1, 5))
s = sub(a, :, [2,3,5])
@test length(reshape(s, length(s))) == 3
r = reshape(s, length(s))
@test length(r) == 3
@test r[1] == 2
@test r[3,1] == 5
@test r[Base.ReshapedIndex(CartesianIndex((1,2)))] == 3
@test parent(reshape(r, (1,3))) === r.parent === s
@test parentindexes(r) == (1:1, 1:3)
@test reshape(r, (3,)) === r
r[2] = -1
@test a[3] == -1
a = zeros(0, 5) # an empty linearslow array
s = sub(a, :, [2,3,5])
@test length(reshape(s, length(s))) == 0

@test reshape(1:5, (5,)) === 1:5
@test reshape(1:5, 5) === 1:5

# setindex! on a reshaped range
a = reshape(1:20, 5, 4)
for idx in ((3,), (2,2), (Base.ReshapedIndex(1),))
try
a[idx...] = 7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is supposed to always fail, needs an error("unexpected") or something so it doesn't silently pass

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

catch err
@test err.msg == "indexed assignment fails for a reshaped range; consider calling collect"
end
end

# operations with LinearFast ReshapedArray
b = collect(1:12)
a = Base.ReshapedArray(b, (4,3), ())
@test a[3,2] == 7
@test a[6] == 6
a[3,2] = -2
a[6] = -3
a[Base.ReshapedIndex(5)] = -4
@test b[5] == -4
@test b[6] == -3
@test b[7] == -2
b = reinterpret(Int, a, (3,4))
b[1] = -1
@test vec(b) == vec(a)

a = rand(1, 1, 8, 8, 1)
@test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1))
@test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8))
Expand Down
67 changes: 34 additions & 33 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,22 +324,21 @@ t1 = bitrand(n1, n2)
b2 = bitrand(countnz(t1))
@check_bit_operation setindex!(b1, b2, t1) BitMatrix

m1 = rand(1:n1)
m2 = rand(1:n2)

t1 = bitrand(n1)
b2 = bitrand(countnz(t1), m2)
k2 = randperm(m2)
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix

t2 = bitrand(n2)
b2 = bitrand(m1, countnz(t2))
k1 = randperm(m1)
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
let m1 = rand(1:n1), m2 = rand(1:n2)
t1 = bitrand(n1)
b2 = bitrand(countnz(t1), m2)
k2 = randperm(m2)
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix

t2 = bitrand(n2)
b2 = bitrand(m1, countnz(t2))
k1 = randperm(m1)
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
end

timesofar("indexing")

Expand Down Expand Up @@ -1054,23 +1053,25 @@ end

## Reductions ##

b1 = bitrand(s1, s2, s3, s4)
m1 = 1
m2 = 3
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}

@check_bit_operation maximum(b1) Bool
@check_bit_operation minimum(b1) Bool
@check_bit_operation any(b1) Bool
@check_bit_operation all(b1) Bool
@check_bit_operation sum(b1) Int

b0 = falses(0)
@check_bit_operation any(b0) Bool
@check_bit_operation all(b0) Bool
@check_bit_operation sum(b0) Int
let
b1 = bitrand(s1, s2, s3, s4)
m1 = 1
m2 = 3
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}

@check_bit_operation maximum(b1) Bool
@check_bit_operation minimum(b1) Bool
@check_bit_operation any(b1) Bool
@check_bit_operation all(b1) Bool
@check_bit_operation sum(b1) Int

b0 = falses(0)
@check_bit_operation any(b0) Bool
@check_bit_operation all(b0) Bool
@check_bit_operation sum(b0) Int
end

timesofar("reductions")

Expand Down