Skip to content

Commit 31c192e

Browse files
committed
Check if index can contain non-zero value before binary search
1 parent 8e67882 commit 31c192e

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

python/pyspark/mllib/linalg/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -769,10 +769,10 @@ def __getitem__(self, index):
769769
if index >= self.size or index < 0:
770770
raise ValueError("Index %d out of bounds." % index)
771771

772-
insert_index = np.searchsorted(inds, index)
773-
if insert_index >= inds.size:
772+
if (inds.size == 0) or (index > inds.item(-1)):
774773
return 0.
775774

775+
insert_index = np.searchsorted(inds, index)
776776
row_ind = inds[insert_index]
777777
if row_ind == index:
778778
return vals[insert_index]

python/pyspark/mllib/tests.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,16 @@ def test_sparse_vector_indexing(self):
252252
for ind in [7.8, '1']:
253253
self.assertRaises(TypeError, sv.__getitem__, ind)
254254

255+
zeros = SparseVector(4, {})
256+
self.assertEqual(zeros[0], 0.0)
257+
self.assertEqual(zeros[3], 0.0)
258+
for ind in [4, -5]:
259+
self.assertRaises(ValueError, zeros.__getitem__, ind)
260+
261+
empty = SparseVector(0, {})
262+
for ind in [-1, 0, 1]:
263+
self.assertRaises(ValueError, empty.__getitem__, ind)
264+
255265
def test_matrix_indexing(self):
256266
mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
257267
expected = [[0, 6], [1, 8], [4, 10]]

0 commit comments

Comments
 (0)