From d28a644f6f65ddc48549766f6037dec2f1f1dc8d Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 7 Oct 2015 13:18:15 +0200 Subject: [PATCH 1/2] [SPARK-10973] __gettitem__ method throws IndexError exception when we try to access index after the last non-zero entry. --- python/pyspark/mllib/linalg/__init__.py | 3 +++ python/pyspark/mllib/tests.py | 12 +++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index ea42127f1651f..1dab377c59dc1 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -770,6 +770,9 @@ def __getitem__(self, index): raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) + if insert_index >= self.indices.size: + return 0. + row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 96cf13495aa95..2a6a5cd3fe40e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -237,15 +237,17 @@ def test_conversion(self): self.assertTrue(dv.array.dtype == 'float64') def test_sparse_vector_indexing(self): - sv = SparseVector(4, {1: 1, 3: 2}) + sv = SparseVector(5, {1: 1, 3: 2}) self.assertEqual(sv[0], 0.) self.assertEqual(sv[3], 2.) self.assertEqual(sv[1], 1.) self.assertEqual(sv[2], 0.) - self.assertEqual(sv[-1], 2) - self.assertEqual(sv[-2], 0) - self.assertEqual(sv[-4], 0) - for ind in [4, -5]: + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) From a1898ee172d1b3b4e8f69650edb2ecbc507f13d7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 8 Oct 2015 09:58:22 +0200 Subject: [PATCH 2/2] Use inds.size instead of self.indices.size --- python/pyspark/mllib/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 1dab377c59dc1..d903b9030d8ce 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -770,7 +770,7 @@ def __getitem__(self, index): raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) - if insert_index >= self.indices.size: + if insert_index >= inds.size: return 0. row_ind = inds[insert_index]