@@ -166,15 +166,15 @@ index_lengths_dim(A, dim, ::Colon) = (trailingsize(A, dim),)
166166@inline index_lengths_dim (A, dim, i:: AbstractArray{Bool} , I... ) = (sum (i), index_lengths_dim (A, dim+ 1 , I... )... )
167167@inline index_lengths_dim (A, dim, i:: AbstractArray , I... ) = (length (i), index_lengths_dim (A, dim+ 1 , I... )... )
168168
169- # shape of array to create for getindex() with indexes I, dropping trailing scalars
169+ # shape of array to create for getindex() with indexes I, dropping scalars
170170index_shape (A:: AbstractArray , I:: AbstractArray ) = size (I) # Linear index reshape
171171index_shape (A:: AbstractArray , I:: AbstractArray{Bool} ) = (sum (I),) # Logical index
172172index_shape (A:: AbstractArray , I:: Colon ) = (length (A),)
173173@inline index_shape (A:: AbstractArray , I... ) = index_shape_dim (A, 1 , I... )
174174index_shape_dim (A, dim, I:: Real... ) = ()
175175index_shape_dim (A, dim, :: Colon ) = (trailingsize (A, dim),)
176176@inline index_shape_dim (A, dim, :: Colon , i, I... ) = (size (A, dim), index_shape_dim (A, dim+ 1 , i, I... )... )
177- @inline index_shape_dim (A, dim, :: Real , I... ) = (1 , index_shape_dim (A, dim+ 1 , I... )... )
177+ @inline index_shape_dim (A, dim, :: Real , I... ) = (index_shape_dim (A, dim+ 1 , I... )... )
178178@inline index_shape_dim (A, dim, i:: AbstractVector{Bool} , I... ) = (sum (i), index_shape_dim (A, dim+ 1 , I... )... )
179179@inline index_shape_dim (A, dim, i:: AbstractVector , I... ) = (length (i), index_shape_dim (A, dim+ 1 , I... )... )
180180
238238 $ (Expr (:meta , :inline ))
239239 D = eachindex (dest)
240240 Ds = start (D)
241- @nloops $ N i dest d-> (j_d = unsafe_getindex (I[d], i_d)) begin
241+ idxlens = index_lengths (src, I... ) # TODO : unsplat?
242+ @nloops $ N i d-> (1 : idxlens[d]) d-> (j_d = unsafe_getindex (I[d], i_d)) begin
242243 d, Ds = next (D, Ds)
243244 v = @ncall $ N unsafe_getindex src j
244245 unsafe_setindex! (dest, v, d)
@@ -248,18 +249,18 @@ end
248249end
249250
250251# checksize ensures the output array A is the correct size for the given indices
251- checksize (A :: AbstractArray , I :: AbstractArray ) = size (A ) == size (I) || throw (DimensionMismatch (" index 1 has size $( size (I)) , but size(A) = $(size (A)) " ))
252- checksize (A :: AbstractArray , I :: AbstractArray{Bool} ) = length (A) == sum (I) || throw ( DimensionMismatch ( " index 1 selects $( sum (I)) elements, but length(A) = $( length (A)) " ))
253- @generated function checksize (A:: AbstractArray , I... )
254- N = length ( I)
255- quote
256- @nexprs $ N d -> ( _checksize (A, d, I[d]) || throw ( DimensionMismatch ( " index $d selects $( length (I[d])) elements, but size( A, $d ) = $( size (A,d)) " )) )
257- end
258- end
259- _checksize (A:: AbstractArray , dim, I ) = size (A, dim) == length (I )
260- _checksize (A:: AbstractArray , dim, I:: AbstractVector{Bool} ) = size (A, dim) == sum (I )
261- _checksize (A:: AbstractArray , dim, :: Colon ) = true
262- _checksize (A:: AbstractArray , dim, :: Real ) = size (A, dim) == 1
252+ @noinline throw_checksize_error (arr, dim, idx ) = throw (DimensionMismatch (" index $d selects $( length (I[d])) elements , but size(A, $d ) = $(size (A,d )) " ))
253+
254+ checksize (A:: AbstractArray , I:: AbstractArray ) = size (A) == size (I) || throw_checksize_error (A, 1 , I )
255+ checksize (A :: AbstractArray , I :: AbstractArray{Bool} ) = length (A) == sum (I) || throw_checksize_error (A, 1 , I)
256+
257+ checksize (A :: AbstractArray , I ... ) = _checksize ( A, 1 , I ... )
258+ _checksize (A :: AbstractArray , dim) = true
259+ # Skip scalars
260+ _checksize (A:: AbstractArray , dim, :: Real , J ... ) = _checksize (A, dim, J ... )
261+ _checksize (A:: AbstractArray , dim, I, J ... ) = ( size (A, dim) == length (I) || throw_checksize_error (A, dim, I); _checksize (A, dim + 1 , J ... ) )
262+ _checksize (A:: AbstractArray , dim, I :: AbstractVector{Bool} , J ... ) = ( size (A, dim) == sum (I) || throw_checksize_error (A, dim, I); _checksize (A, dim + 1 , J ... ))
263+ _checksize (A:: AbstractArray , dim, :: Colon , J ... ) = _checksize (A, dim+ 1 , J ... )
263264
264265@inline unsafe_setindex! (v:: BitArray , x:: Bool , ind:: Int ) = (Base. unsafe_bitsetindex! (v. chunks, x, ind); v)
265266@inline unsafe_setindex! (v:: BitArray , x, ind:: Real ) = (Base. unsafe_bitsetindex! (v. chunks, convert (Bool, x), to_index (ind)); v)
585586
586587 storeind = 1
587588 Xc, Bc = X. chunks, B. chunks
588- @nloops ($ N, i, d-> 1 : size (X, d+ 1 ),
589+ idxlens = index_lengths (B, I0, I... ) # TODO : unsplat?
590+ @nloops ($ N, i, d-> (1 : idxlens[d+ 1 ]),
589591 d-> nothing , # PRE
590592 d-> (ind += stride_lst_d - gap_lst_d), # POST
591593 begin # BODY
608610 $ (symbol (:offset_ , N)) = 1
609611 ind = 0
610612 Xc, Bc = X. chunks, B. chunks
611- @nloops $ N i X d-> (offset_{d- 1 } = offset_d + (unsafe_getindex (I[d], i_d)- 1 )* stride_d) begin
613+ idxlens = index_lengths (B, I... ) # TODO : unsplat?
614+ @nloops $ N i d-> (1 : idxlens[d]) d-> (offset_{d- 1 } = offset_d + (unsafe_getindex (I[d], i_d)- 1 )* stride_d) begin
612615 ind += 1
613616 unsafe_bitsetindex! (Xc, unsafe_bitgetindex (Bc, offset_0), ind)
614617 end
0 commit comments