From 31c192efafe0e0acd95a0c01447c4df67ce7d3be Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 13 Oct 2015 13:59:42 +0200
Subject: [PATCH] Check if index can contain non-zero value before binary search

---
 python/pyspark/mllib/linalg/__init__.py |  4 ++--
 python/pyspark/mllib/tests.py           | 10 ++++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index d903b9030d8c..e86668a3601f 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -769,10 +769,10 @@ def __getitem__(self, index):
         if index >= self.size or index < 0:
             raise ValueError("Index %d out of bounds." % index)
 
-        insert_index = np.searchsorted(inds, index)
-        if insert_index >= inds.size:
+        if (inds.size == 0) or (index > inds.item(-1)):
             return 0.
 
+        insert_index = np.searchsorted(inds, index)
         row_ind = inds[insert_index]
         if row_ind == index:
             return vals[insert_index]
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 2a6a5cd3fe40..2ad69a0ab1d3 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -252,6 +252,16 @@ def test_sparse_vector_indexing(self):
         for ind in [7.8, '1']:
             self.assertRaises(TypeError, sv.__getitem__, ind)
 
+        zeros = SparseVector(4, {})
+        self.assertEqual(zeros[0], 0.0)
+        self.assertEqual(zeros[3], 0.0)
+        for ind in [4, -5]:
+            self.assertRaises(ValueError, zeros.__getitem__, ind)
+
+        empty = SparseVector(0, {})
+        for ind in [-1, 0, 1]:
+            self.assertRaises(ValueError, empty.__getitem__, ind)
+
     def test_matrix_indexing(self):
         mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
         expected = [[0, 6], [1, 8], [4, 10]]