Skip to content

Commit df2b34c

Browse files
committed
[SPARK-10973][ML][PYTHON] Fix IndexError exception on SparseVector when asked for index after the last non-zero entry
1 parent d4a74a2 commit df2b34c

File tree

2 files changed: 14 additions and 9 deletions.

2 files changed: 14 additions and 9 deletions.

python/pyspark/mllib/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -673,6 +673,9 @@ def __getitem__(self, index):
673 673
raise ValueError("Index %d out of bounds." % index)
674 674

675 675
insert_index = np.searchsorted(inds, index)
676+
if insert_index >= inds.size:
677+
return 0.0
678+
676 679
row_ind = inds[insert_index]
677 680
if row_ind == index:
678 681
return vals[insert_index]

python/pyspark/mllib/tests.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -137,15 +137,17 @@ def test_conversion(self):
137 137
self.assertTrue(dv.array.dtype == 'float64')
138 138

139 139
def test_sparse_vector_indexing(self):
140-
sv = SparseVector(4, {1: 1, 3: 2})
141-
self.assertEquals(sv[0], 0.)
142-
self.assertEquals(sv[3], 2.)
143-
self.assertEquals(sv[1], 1.)
144-
self.assertEquals(sv[2], 0.)
145-
self.assertEquals(sv[-1], 2)
146-
self.assertEquals(sv[-2], 0)
147-
self.assertEquals(sv[-4], 0)
148-
for ind in [4, -5]:
140+
sv = SparseVector(5, {1: 1, 3: 2})
141+
self.assertEqual(sv[0], 0.)
142+
self.assertEqual(sv[3], 2.)
143+
self.assertEqual(sv[1], 1.)
144+
self.assertEqual(sv[2], 0.)
145+
self.assertEqual(sv[4], 0.)
146+
self.assertEqual(sv[-1], 0.)
147+
self.assertEqual(sv[-2], 2.)
148+
self.assertEqual(sv[-3], 0.)
149+
self.assertEqual(sv[-5], 0.)
150+
for ind in [5, -6]:
149 151
self.assertRaises(ValueError, sv.__getitem__, ind)
150 152
for ind in [7.8, '1']:
151 153
self.assertRaises(TypeError, sv.__getitem__, ind)

0 commit comments

Comments (0)