From df2b34c9dfdcfa6fa5bf43343a4663c1d1a014b9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sat, 10 Oct 2015 18:24:43 +0200 Subject: [PATCH] [SPARK-10973][ML][PYTHON] Fix IndexError exception on SparseVector when asked for index after the last non-zero entry --- python/pyspark/mllib/linalg.py | 3 +++ python/pyspark/mllib/tests.py | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7702beb12714..ee1ad033f27a 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -673,6 +673,9 @@ def __getitem__(self, index): raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) + if insert_index >= inds.size: + return 0.0 + row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 4335143a8dd5..d883f6f14271 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -137,15 +137,17 @@ def test_conversion(self): self.assertTrue(dv.array.dtype == 'float64') def test_sparse_vector_indexing(self): - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) - for ind in [4, -5]: + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind)