2222from numpy import array , array_equal , zeros , arange , tile , ones , inf
2323
2424import pyspark .ml .linalg as newlinalg
25+ from pyspark .serializers import PickleSerializer
2526from pyspark .mllib .linalg import Vector , SparseVector , DenseVector , VectorUDT , _convert_to_vector , \
2627 DenseMatrix , SparseMatrix , Vectors , Matrices , MatrixUDT
2728from pyspark .mllib .regression import LabeledPoint
28- from pyspark .testing .mllibutils import make_serializer , MLlibTestCase
29-
30- _have_scipy = False
31- try :
32- import scipy .sparse
33- _have_scipy = True
34- except :
35- # No SciPy, but that's okay, we'll skip those tests
36- pass
37-
38-
39- ser = make_serializer ()
40-
41-
42- def _squared_distance (a , b ):
43- if isinstance (a , Vector ):
44- return a .squared_distance (b )
45- else :
46- return b .squared_distance (a )
29+ from pyspark .testing .mllibutils import MLlibTestCase
30+ from pyspark .testing .utils import have_scipy
4731
4832
4933class VectorTests (MLlibTestCase ):
5034
5135 def _test_serialize (self , v ):
36+ ser = PickleSerializer ()
5237 self .assertEqual (v , ser .loads (ser .dumps (v )))
5338 jvec = self .sc ._jvm .org .apache .spark .mllib .api .python .SerDe .loads (bytearray (ser .dumps (v )))
5439 nv = ser .loads (bytes (self .sc ._jvm .org .apache .spark .mllib .api .python .SerDe .dumps (jvec )))
@@ -87,24 +72,30 @@ def test_dot(self):
8772 self .assertEqual (7.0 , sv .dot (arr ))
8873
def test_squared_distance(self):
    """Check ``squared_distance`` for sparse/dense vectors against vectors,
    a plain list, and array operands.

    Every expected value is the sum of squared element-wise differences of
    the two length-4 operands built below.
    """

    def squared_distance(a, b):
        # Invoke the method on whichever operand is a Vector, so a
        # non-Vector operand may be passed in either position.
        if isinstance(a, Vector):
            return a.squared_distance(b)
        else:
            return b.squared_distance(a)

    # Operands: sv is [0, 1, 0, 2] in sparse form; lst is a DenseVector
    # built from a list, while lst1 stays a plain Python list.
    sv = SparseVector(4, {1: 1, 3: 2})
    dv = DenseVector(array([1., 2., 3., 4.]))
    lst = DenseVector([4, 3, 2, 1])
    lst1 = [4, 3, 2, 1]
    # NOTE(review): `pyarray` is imported outside this view -- presumably
    # the stdlib `array` module; `array` is NumPy's.
    arr = pyarray.array('d', [0, 2, 1, 3])
    narr = array([0, 2, 1, 3])

    # Mixed sparse/dense pairs, in both argument orders.
    self.assertEqual(15.0, squared_distance(sv, dv))
    self.assertEqual(25.0, squared_distance(sv, lst))
    self.assertEqual(20.0, squared_distance(dv, lst))
    self.assertEqual(15.0, squared_distance(dv, sv))
    self.assertEqual(25.0, squared_distance(lst, sv))
    self.assertEqual(20.0, squared_distance(lst, dv))

    # Distance from a vector to itself is zero.
    self.assertEqual(0.0, squared_distance(sv, sv))
    self.assertEqual(0.0, squared_distance(dv, dv))
    self.assertEqual(0.0, squared_distance(lst, lst))

    # Sparse vector against non-Vector operands.
    self.assertEqual(25.0, squared_distance(sv, lst1))
    self.assertEqual(3.0, squared_distance(sv, arr))
    self.assertEqual(3.0, squared_distance(sv, narr))
10899
109100 def test_hash (self ):
110101 v1 = DenseVector ([0.0 , 1.0 , 0.0 , 5.5 ])
@@ -466,7 +457,7 @@ def test_infer_schema(self):
466457 raise ValueError ("Expected a matrix but got type %r" % type (m ))
467458
468459
469- @unittest .skipIf (not _have_scipy , "SciPy not installed" )
460+ @unittest .skipIf (not have_scipy , "SciPy not installed" )
470461class SciPyTests (MLlibTestCase ):
471462
472463 """
@@ -476,6 +467,8 @@ class SciPyTests(MLlibTestCase):
476467
477468 def test_serialize (self ):
478469 from scipy .sparse import lil_matrix
470+
471+ ser = PickleSerializer ()
479472 lil = lil_matrix ((4 , 1 ))
480473 lil [1 , 0 ] = 1
481474 lil [3 , 0 ] = 2
@@ -621,13 +614,10 @@ def test_regression(self):
621614
if __name__ == "__main__":
    # NOTE(review): star-imports this module under its package path --
    # presumably so the test classes resolve by qualified name when the
    # file is run as a script; confirm before changing.
    from pyspark.mllib.tests.test_linalg import *

    try:
        # Emit XML reports under target/test-reports when xmlrunner is
        # available.
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        # testRunner=None lets unittest.main fall back to its default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
0 commit comments