Skip to content

Commit 8ac71d6

Browse files
zero323jkbradley
authored andcommitted
[SPARK-11084] [ML] [PYTHON] Check if index can contain non-zero value before binary search
At this moment `SparseVector.__getitem__` executes `np.searchsorted` first and checks if result is in an expected range after that. It is possible to check if index can contain non-zero value before executing `np.searchsorted`. Author: zero323 <[email protected]> Closes #9098 from zero323/sparse_vector_getitem_improved.
1 parent 10046ea commit 8ac71d6

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

python/pyspark/mllib/linalg/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -770,10 +770,10 @@ def __getitem__(self, index):
770770
if index < 0:
771771
index += self.size
772772

773-
insert_index = np.searchsorted(inds, index)
774-
if insert_index >= inds.size:
773+
if (inds.size == 0) or (index > inds.item(-1)):
775774
return 0.
776775

776+
insert_index = np.searchsorted(inds, index)
777777
row_ind = inds[insert_index]
778778
if row_ind == index:
779779
return vals[insert_index]

python/pyspark/mllib/tests.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,16 @@ def test_sparse_vector_indexing(self):
252252
for ind in [7.8, '1']:
253253
self.assertRaises(TypeError, sv.__getitem__, ind)
254254

255+
zeros = SparseVector(4, {})
256+
self.assertEqual(zeros[0], 0.0)
257+
self.assertEqual(zeros[3], 0.0)
258+
for ind in [4, -5]:
259+
self.assertRaises(ValueError, zeros.__getitem__, ind)
260+
261+
empty = SparseVector(0, {})
262+
for ind in [-1, 0, 1]:
263+
self.assertRaises(ValueError, empty.__getitem__, ind)
264+
255265
def test_matrix_indexing(self):
256266
mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
257267
expected = [[0, 6], [1, 8], [4, 10]]

0 commit comments

Comments
 (0)