Skip to content

Commit d28a644

Browse files
committed
[SPARK-10973] __gettitem__ method throws IndexError exception when we try to access index after the last non-zero entry.
1 parent ffe6831 commit d28a644

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/pyspark/mllib/linalg/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,9 @@ def __getitem__(self, index):
770770
raise ValueError("Index %d out of bounds." % index)
771771

772772
insert_index = np.searchsorted(inds, index)
773+
if insert_index >= self.indices.size:
774+
return 0.
775+
773776
row_ind = inds[insert_index]
774777
if row_ind == index:
775778
return vals[insert_index]

python/pyspark/mllib/tests.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,17 @@ def test_conversion(self):
237237
self.assertTrue(dv.array.dtype == 'float64')
238238

239239
def test_sparse_vector_indexing(self):
240-
sv = SparseVector(4, {1: 1, 3: 2})
240+
sv = SparseVector(5, {1: 1, 3: 2})
241241
self.assertEqual(sv[0], 0.)
242242
self.assertEqual(sv[3], 2.)
243243
self.assertEqual(sv[1], 1.)
244244
self.assertEqual(sv[2], 0.)
245-
self.assertEqual(sv[-1], 2)
246-
self.assertEqual(sv[-2], 0)
247-
self.assertEqual(sv[-4], 0)
248-
for ind in [4, -5]:
245+
self.assertEqual(sv[4], 0.)
246+
self.assertEqual(sv[-1], 0.)
247+
self.assertEqual(sv[-2], 2.)
248+
self.assertEqual(sv[-3], 0.)
249+
self.assertEqual(sv[-5], 0.)
250+
for ind in [5, -6]:
249251
self.assertRaises(ValueError, sv.__getitem__, ind)
250252
for ind in [7.8, '1']:
251253
self.assertRaises(TypeError, sv.__getitem__, ind)

0 commit comments

Comments
 (0)