55import numpy as np
66import pytest
77
8- from .. import ones , asarray , result_type , all , equal
8+ from .. import ones , arange , reshape , asarray , result_type , all , equal
99from .._array_object import Array , CPU_DEVICE , Device
1010from .._dtypes import (
1111 _all_dtypes ,
@@ -70,11 +70,25 @@ def test_validate_index():
7070 assert_raises (IndexError , lambda : a [[True , True , True ]])
7171 assert_raises (IndexError , lambda : a [(True , True , True ),])
7272
73- # Integer array indices are not allowed (except for 0-D)
74- idx = asarray ([0 , 1 ])
73+ # Integer array indices are not allowed (except for 0-D or 1D )
74+ idx = asarray ([[ 0 , 1 ]]) # idx.ndim == 2
7575 assert_raises (IndexError , lambda : a [idx , 0 ])
7676 assert_raises (IndexError , lambda : a [0 , idx ])
7777
78+ # Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
79+ idx = asarray ([0 , 1 ])
80+ assert_raises (IndexError , lambda : a [..., idx ])
81+ assert_raises (IndexError , lambda : a [:, idx ])
82+ assert_raises (IndexError , lambda : a [asarray ([True , True ]), idx ])
83+
84+ # 1D integer array indices must have the same length
85+ idx1 = asarray ([0 , 1 ])
86+ idx2 = asarray ([0 , 1 , 1 ])
87+ assert_raises (IndexError , lambda : a [idx1 , idx2 ])
88+
89+ # Non-integer array indices are not allowed
90+ assert_raises (IndexError , lambda : a [ones (2 ), 0 ])
91+
7892 # Array-likes (lists, tuples) are not allowed as indices
7993 assert_raises (IndexError , lambda : a [[0 , 1 ]])
8094 assert_raises (IndexError , lambda : a [(0 , 1 ), (0 , 1 )])
@@ -91,6 +105,37 @@ def test_validate_index():
91105 assert_raises (IndexError , lambda : a [:])
92106 assert_raises (IndexError , lambda : a [idx ])
93107
108+
109+ def test_indexing_arrays ():
110+ # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
111+
112+ # 1D array
113+ a = arange (5 )
114+ idx = asarray ([1 , 0 , 1 , 2 , - 1 ])
115+ a_idx = a [idx ]
116+
117+ a_idx_loop = asarray ([a [idx [i ]] for i in range (idx .shape [0 ])])
118+ assert all (a_idx == a_idx_loop )
119+
120+ # setitem with arrays is not allowed # XXX
121+ # with assert_raises(IndexError):
122+ # a[idx] = 42
123+
124+ # mixed array and integer indexing
125+ a = reshape (arange (3 * 4 ), (3 , 4 ))
126+ idx = asarray ([1 , 0 , 1 , 2 , - 1 ])
127+ a_idx = a [idx , 1 ]
128+
129+ a_idx_loop = asarray ([a [idx [i ], 1 ] for i in range (idx .shape [0 ])])
130+ assert all (a_idx == a_idx_loop )
131+
132+
133+ # index with two arrays
134+ a_idx = a [idx , idx ]
135+ a_idx_loop = asarray ([a [idx [i ], idx [i ]] for i in range (idx .shape [0 ])])
136+ assert all (a_idx == a_idx_loop )
137+
138+
94139def test_promoted_scalar_inherits_device ():
95140 device1 = Device ("device1" )
96141 x = asarray ([1. , 2 , 3 ], device = device1 )
0 commit comments