Skip to content

Commit 518a3d1

Browse files
HyukjinKwon authored and BryanCutler committed
[SPARK-26033][SPARK-26034][PYTHON][FOLLOW-UP] Small cleanup and deduplication in ml/mllib tests
## What changes were proposed in this pull request? This PR is a small follow up that puts some logic and functions into smaller scope and make it localized, and deduplicate. ## How was this patch tested? Manually tested. Jenkins tests as well. Closes #23200 from HyukjinKwon/followup-SPARK-26034-SPARK-26033. Authored-by: Hyukjin Kwon <[email protected]> Signed-off-by: Bryan Cutler <[email protected]>
1 parent 187bb7d commit 518a3d1

File tree

4 files changed

+51
-68
lines changed

4 files changed

+51
-68
lines changed

python/pyspark/ml/tests/test_linalg.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,17 @@
2020

2121
from numpy import arange, array, array_equal, inf, ones, tile, zeros
2222

23+
from pyspark.serializers import PickleSerializer
2324
from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \
2425
Vector, VectorUDT, Vectors
25-
from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
26+
from pyspark.testing.mllibutils import MLlibTestCase
2627
from pyspark.sql import Row
2728

2829

29-
ser = make_serializer()
30-
31-
32-
def _squared_distance(a, b):
33-
if isinstance(a, Vector):
34-
return a.squared_distance(b)
35-
else:
36-
return b.squared_distance(a)
37-
38-
3930
class VectorTests(MLlibTestCase):
4031

4132
def _test_serialize(self, v):
33+
ser = PickleSerializer()
4234
self.assertEqual(v, ser.loads(ser.dumps(v)))
4335
jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
4436
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
@@ -77,24 +69,30 @@ def test_dot(self):
7769
self.assertEqual(7.0, sv.dot(arr))
7870

7971
def test_squared_distance(self):
72+
def squared_distance(a, b):
73+
if isinstance(a, Vector):
74+
return a.squared_distance(b)
75+
else:
76+
return b.squared_distance(a)
77+
8078
sv = SparseVector(4, {1: 1, 3: 2})
8179
dv = DenseVector(array([1., 2., 3., 4.]))
8280
lst = DenseVector([4, 3, 2, 1])
8381
lst1 = [4, 3, 2, 1]
8482
arr = pyarray.array('d', [0, 2, 1, 3])
8583
narr = array([0, 2, 1, 3])
86-
self.assertEqual(15.0, _squared_distance(sv, dv))
87-
self.assertEqual(25.0, _squared_distance(sv, lst))
88-
self.assertEqual(20.0, _squared_distance(dv, lst))
89-
self.assertEqual(15.0, _squared_distance(dv, sv))
90-
self.assertEqual(25.0, _squared_distance(lst, sv))
91-
self.assertEqual(20.0, _squared_distance(lst, dv))
92-
self.assertEqual(0.0, _squared_distance(sv, sv))
93-
self.assertEqual(0.0, _squared_distance(dv, dv))
94-
self.assertEqual(0.0, _squared_distance(lst, lst))
95-
self.assertEqual(25.0, _squared_distance(sv, lst1))
96-
self.assertEqual(3.0, _squared_distance(sv, arr))
97-
self.assertEqual(3.0, _squared_distance(sv, narr))
84+
self.assertEqual(15.0, squared_distance(sv, dv))
85+
self.assertEqual(25.0, squared_distance(sv, lst))
86+
self.assertEqual(20.0, squared_distance(dv, lst))
87+
self.assertEqual(15.0, squared_distance(dv, sv))
88+
self.assertEqual(25.0, squared_distance(lst, sv))
89+
self.assertEqual(20.0, squared_distance(lst, dv))
90+
self.assertEqual(0.0, squared_distance(sv, sv))
91+
self.assertEqual(0.0, squared_distance(dv, dv))
92+
self.assertEqual(0.0, squared_distance(lst, lst))
93+
self.assertEqual(25.0, squared_distance(sv, lst1))
94+
self.assertEqual(3.0, squared_distance(sv, arr))
95+
self.assertEqual(3.0, squared_distance(sv, narr))
9896

9997
def test_hash(self):
10098
v1 = DenseVector([0.0, 1.0, 0.0, 5.5])

python/pyspark/mllib/tests/test_algorithms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@
2626
from pyspark.mllib.fpm import FPGrowth
2727
from pyspark.mllib.recommendation import Rating
2828
from pyspark.mllib.regression import LabeledPoint
29-
from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
30-
31-
32-
ser = make_serializer()
29+
from pyspark.serializers import PickleSerializer
30+
from pyspark.testing.mllibutils import MLlibTestCase
3331

3432

3533
class ListTests(MLlibTestCase):
@@ -265,6 +263,7 @@ def test_regression(self):
265263
class ALSTests(MLlibTestCase):
266264

267265
def test_als_ratings_serialize(self):
266+
ser = PickleSerializer()
268267
r = Rating(7, 1123, 3.14)
269268
jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
270269
nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
@@ -273,6 +272,7 @@ def test_als_ratings_serialize(self):
273272
self.assertAlmostEqual(r.rating, nr.rating, 2)
274273

275274
def test_als_ratings_id_long_error(self):
275+
ser = PickleSerializer()
276276
r = Rating(1205640308657491975, 50233468418, 1.0)
277277
# rating user id exceeds max int value, should fail when pickled
278278
self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,

python/pyspark/mllib/tests/test_linalg.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,18 @@
2222
from numpy import array, array_equal, zeros, arange, tile, ones, inf
2323

2424
import pyspark.ml.linalg as newlinalg
25+
from pyspark.serializers import PickleSerializer
2526
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
2627
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
2728
from pyspark.mllib.regression import LabeledPoint
28-
from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
29-
30-
_have_scipy = False
31-
try:
32-
import scipy.sparse
33-
_have_scipy = True
34-
except:
35-
# No SciPy, but that's okay, we'll skip those tests
36-
pass
37-
38-
39-
ser = make_serializer()
40-
41-
42-
def _squared_distance(a, b):
43-
if isinstance(a, Vector):
44-
return a.squared_distance(b)
45-
else:
46-
return b.squared_distance(a)
29+
from pyspark.testing.mllibutils import MLlibTestCase
30+
from pyspark.testing.utils import have_scipy
4731

4832

4933
class VectorTests(MLlibTestCase):
5034

5135
def _test_serialize(self, v):
36+
ser = PickleSerializer()
5237
self.assertEqual(v, ser.loads(ser.dumps(v)))
5338
jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
5439
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
@@ -87,24 +72,30 @@ def test_dot(self):
8772
self.assertEqual(7.0, sv.dot(arr))
8873

8974
def test_squared_distance(self):
75+
def squared_distance(a, b):
76+
if isinstance(a, Vector):
77+
return a.squared_distance(b)
78+
else:
79+
return b.squared_distance(a)
80+
9081
sv = SparseVector(4, {1: 1, 3: 2})
9182
dv = DenseVector(array([1., 2., 3., 4.]))
9283
lst = DenseVector([4, 3, 2, 1])
9384
lst1 = [4, 3, 2, 1]
9485
arr = pyarray.array('d', [0, 2, 1, 3])
9586
narr = array([0, 2, 1, 3])
96-
self.assertEqual(15.0, _squared_distance(sv, dv))
97-
self.assertEqual(25.0, _squared_distance(sv, lst))
98-
self.assertEqual(20.0, _squared_distance(dv, lst))
99-
self.assertEqual(15.0, _squared_distance(dv, sv))
100-
self.assertEqual(25.0, _squared_distance(lst, sv))
101-
self.assertEqual(20.0, _squared_distance(lst, dv))
102-
self.assertEqual(0.0, _squared_distance(sv, sv))
103-
self.assertEqual(0.0, _squared_distance(dv, dv))
104-
self.assertEqual(0.0, _squared_distance(lst, lst))
105-
self.assertEqual(25.0, _squared_distance(sv, lst1))
106-
self.assertEqual(3.0, _squared_distance(sv, arr))
107-
self.assertEqual(3.0, _squared_distance(sv, narr))
87+
self.assertEqual(15.0, squared_distance(sv, dv))
88+
self.assertEqual(25.0, squared_distance(sv, lst))
89+
self.assertEqual(20.0, squared_distance(dv, lst))
90+
self.assertEqual(15.0, squared_distance(dv, sv))
91+
self.assertEqual(25.0, squared_distance(lst, sv))
92+
self.assertEqual(20.0, squared_distance(lst, dv))
93+
self.assertEqual(0.0, squared_distance(sv, sv))
94+
self.assertEqual(0.0, squared_distance(dv, dv))
95+
self.assertEqual(0.0, squared_distance(lst, lst))
96+
self.assertEqual(25.0, squared_distance(sv, lst1))
97+
self.assertEqual(3.0, squared_distance(sv, arr))
98+
self.assertEqual(3.0, squared_distance(sv, narr))
10899

109100
def test_hash(self):
110101
v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
@@ -466,7 +457,7 @@ def test_infer_schema(self):
466457
raise ValueError("Expected a matrix but got type %r" % type(m))
467458

468459

469-
@unittest.skipIf(not _have_scipy, "SciPy not installed")
460+
@unittest.skipIf(not have_scipy, "SciPy not installed")
470461
class SciPyTests(MLlibTestCase):
471462

472463
"""
@@ -476,6 +467,8 @@ class SciPyTests(MLlibTestCase):
476467

477468
def test_serialize(self):
478469
from scipy.sparse import lil_matrix
470+
471+
ser = PickleSerializer()
479472
lil = lil_matrix((4, 1))
480473
lil[1, 0] = 1
481474
lil[3, 0] = 2
@@ -621,13 +614,10 @@ def test_regression(self):
621614

622615
if __name__ == "__main__":
623616
from pyspark.mllib.tests.test_linalg import *
624-
if not _have_scipy:
625-
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
617+
626618
try:
627619
import xmlrunner
628620
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
629621
except ImportError:
630622
testRunner = None
631623
unittest.main(testRunner=testRunner, verbosity=2)
632-
if not _have_scipy:
633-
print("NOTE: SciPy tests were skipped as it does not seem to be installed")

python/pyspark/testing/mllibutils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,9 @@
1818
import unittest
1919

2020
from pyspark import SparkContext
21-
from pyspark.serializers import PickleSerializer
2221
from pyspark.sql import SparkSession
2322

2423

25-
def make_serializer():
26-
return PickleSerializer()
27-
28-
2924
class MLlibTestCase(unittest.TestCase):
3025
def setUp(self):
3126
self.sc = SparkContext('local[4]', "MLlib tests")

0 commit comments

Comments (0)