44 changes: 21 additions & 23 deletions python/pyspark/ml/tests/test_linalg.py
@@ -20,25 +20,17 @@

from numpy import arange, array, array_equal, inf, ones, tile, zeros

+from pyspark.serializers import PickleSerializer
from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \
    Vector, VectorUDT, Vectors
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+from pyspark.testing.mllibutils import MLlibTestCase
from pyspark.sql import Row


-ser = make_serializer()
-
-
-def _squared_distance(a, b):
-    if isinstance(a, Vector):
-        return a.squared_distance(b)
-    else:
-        return b.squared_distance(a)


class VectorTests(MLlibTestCase):

    def _test_serialize(self, v):
+        ser = PickleSerializer()
        self.assertEqual(v, ser.loads(ser.dumps(v)))
        jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
        nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
@@ -77,24 +69,30 @@ def test_dot(self):
        self.assertEqual(7.0, sv.dot(arr))

    def test_squared_distance(self):
+        def squared_distance(a, b):
+            if isinstance(a, Vector):
+                return a.squared_distance(b)
+            else:
+                return b.squared_distance(a)

        sv = SparseVector(4, {1: 1, 3: 2})
        dv = DenseVector(array([1., 2., 3., 4.]))
        lst = DenseVector([4, 3, 2, 1])
        lst1 = [4, 3, 2, 1]
        arr = pyarray.array('d', [0, 2, 1, 3])
        narr = array([0, 2, 1, 3])
-        self.assertEqual(15.0, _squared_distance(sv, dv))
-        self.assertEqual(25.0, _squared_distance(sv, lst))
-        self.assertEqual(20.0, _squared_distance(dv, lst))
-        self.assertEqual(15.0, _squared_distance(dv, sv))
-        self.assertEqual(25.0, _squared_distance(lst, sv))
-        self.assertEqual(20.0, _squared_distance(lst, dv))
-        self.assertEqual(0.0, _squared_distance(sv, sv))
-        self.assertEqual(0.0, _squared_distance(dv, dv))
-        self.assertEqual(0.0, _squared_distance(lst, lst))
-        self.assertEqual(25.0, _squared_distance(sv, lst1))
-        self.assertEqual(3.0, _squared_distance(sv, arr))
-        self.assertEqual(3.0, _squared_distance(sv, narr))
+        self.assertEqual(15.0, squared_distance(sv, dv))
+        self.assertEqual(25.0, squared_distance(sv, lst))
+        self.assertEqual(20.0, squared_distance(dv, lst))
+        self.assertEqual(15.0, squared_distance(dv, sv))
+        self.assertEqual(25.0, squared_distance(lst, sv))
+        self.assertEqual(20.0, squared_distance(lst, dv))
+        self.assertEqual(0.0, squared_distance(sv, sv))
+        self.assertEqual(0.0, squared_distance(dv, dv))
+        self.assertEqual(0.0, squared_distance(lst, lst))
+        self.assertEqual(25.0, squared_distance(sv, lst1))
+        self.assertEqual(3.0, squared_distance(sv, arr))
+        self.assertEqual(3.0, squared_distance(sv, narr))

    def test_hash(self):
        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
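For context: this file drops the module-level ser = make_serializer() and _squared_distance helpers in favor of per-test state. A minimal sketch of the round trip that _test_serialize now performs, assuming a live SparkContext named sc (which MLlibTestCase's setUp provides); the helper name roundtrip_through_jvm is illustrative, not part of the patch:

from pyspark.serializers import PickleSerializer

def roundtrip_through_jvm(sc, v):
    # Pickle v on the Python side, decode it on the JVM via MLSerDe,
    # then dump it back and unpickle -- mirroring _test_serialize above.
    ser = PickleSerializer()  # created locally; no shared module-level serializer
    jvec = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
    return ser.loads(bytes(sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))

# A round-tripped vector should compare equal to the original, e.g.
# roundtrip_through_jvm(sc, DenseVector([1.0, 2.0])) == DenseVector([1.0, 2.0])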
8 changes: 4 additions & 4 deletions python/pyspark/mllib/tests/test_algorithms.py
@@ -26,10 +26,8 @@
from pyspark.mllib.fpm import FPGrowth
from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
-
-
-ser = make_serializer()
+from pyspark.serializers import PickleSerializer
+from pyspark.testing.mllibutils import MLlibTestCase


class ListTests(MLlibTestCase):
@@ -265,6 +263,7 @@ def test_regression(self):
class ALSTests(MLlibTestCase):

    def test_als_ratings_serialize(self):
+        ser = PickleSerializer()
        r = Rating(7, 1123, 3.14)
        jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
        nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
@@ -273,6 +272,7 @@ def test_als_ratings_serialize(self):
        self.assertAlmostEqual(r.rating, nr.rating, 2)

    def test_als_ratings_id_long_error(self):
+        ser = PickleSerializer()
        r = Rating(1205640308657491975, 50233468418, 1.0)
        # rating user id exceeds max int value, should fail when pickled
        self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
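A note on test_als_ratings_id_long_error: per its inline comment, the user id exceeds the maximum Java int, so shipping the pickled Rating to SerDe.loads is expected to raise Py4JJavaError. A sketch of that expectation outside the unittest harness, again assuming a live sc; the helper name overflows_java_int is illustrative only:

from py4j.protocol import Py4JJavaError
from pyspark.serializers import PickleSerializer
from pyspark.mllib.recommendation import Rating

def overflows_java_int(sc):
    ser = PickleSerializer()
    r = Rating(1205640308657491975, 50233468418, 1.0)  # user id > 2**31 - 1
    try:
        sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
    except Py4JJavaError:
        return True   # the failure the test asserts via assertRaises
    return False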
62 changes: 26 additions & 36 deletions python/pyspark/mllib/tests/test_linalg.py
@@ -22,33 +22,18 @@
from numpy import array, array_equal, zeros, arange, tile, ones, inf

import pyspark.ml.linalg as newlinalg
+from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
-
-_have_scipy = False
-try:
-    import scipy.sparse
-    _have_scipy = True
-except:
-    # No SciPy, but that's okay, we'll skip those tests
-    pass
-
-
-ser = make_serializer()
-
-
-def _squared_distance(a, b):
-    if isinstance(a, Vector):
-        return a.squared_distance(b)
-    else:
-        return b.squared_distance(a)
+from pyspark.testing.mllibutils import MLlibTestCase
+from pyspark.testing.utils import have_scipy
Review comment (Member): Oh that's good, didn't realize have_scipy was there
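For readers wondering about that comment: have_scipy is presumably the same guarded-import flag this file used to define locally as _have_scipy, now shared from pyspark.testing.utils. A sketch of the pattern (the exact definition in pyspark/testing/utils.py may differ slightly):

have_scipy = False
try:
    import scipy.sparse  # probe only; the tests import what they need themselves
    have_scipy = True
except ImportError:
    # No SciPy: the @unittest.skipIf-guarded SciPy tests will simply be skipped.
    pass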

class VectorTests(MLlibTestCase):

    def _test_serialize(self, v):
+        ser = PickleSerializer()
        self.assertEqual(v, ser.loads(ser.dumps(v)))
        jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
        nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
@@ -87,24 +72,30 @@ def test_dot(self):
        self.assertEqual(7.0, sv.dot(arr))

    def test_squared_distance(self):
+        def squared_distance(a, b):
+            if isinstance(a, Vector):
+                return a.squared_distance(b)
+            else:
+                return b.squared_distance(a)

        sv = SparseVector(4, {1: 1, 3: 2})
        dv = DenseVector(array([1., 2., 3., 4.]))
        lst = DenseVector([4, 3, 2, 1])
        lst1 = [4, 3, 2, 1]
        arr = pyarray.array('d', [0, 2, 1, 3])
        narr = array([0, 2, 1, 3])
-        self.assertEqual(15.0, _squared_distance(sv, dv))
-        self.assertEqual(25.0, _squared_distance(sv, lst))
-        self.assertEqual(20.0, _squared_distance(dv, lst))
-        self.assertEqual(15.0, _squared_distance(dv, sv))
-        self.assertEqual(25.0, _squared_distance(lst, sv))
-        self.assertEqual(20.0, _squared_distance(lst, dv))
-        self.assertEqual(0.0, _squared_distance(sv, sv))
-        self.assertEqual(0.0, _squared_distance(dv, dv))
-        self.assertEqual(0.0, _squared_distance(lst, lst))
-        self.assertEqual(25.0, _squared_distance(sv, lst1))
-        self.assertEqual(3.0, _squared_distance(sv, arr))
-        self.assertEqual(3.0, _squared_distance(sv, narr))
+        self.assertEqual(15.0, squared_distance(sv, dv))
+        self.assertEqual(25.0, squared_distance(sv, lst))
+        self.assertEqual(20.0, squared_distance(dv, lst))
+        self.assertEqual(15.0, squared_distance(dv, sv))
+        self.assertEqual(25.0, squared_distance(lst, sv))
+        self.assertEqual(20.0, squared_distance(lst, dv))
+        self.assertEqual(0.0, squared_distance(sv, sv))
+        self.assertEqual(0.0, squared_distance(dv, dv))
+        self.assertEqual(0.0, squared_distance(lst, lst))
+        self.assertEqual(25.0, squared_distance(sv, lst1))
+        self.assertEqual(3.0, squared_distance(sv, arr))
+        self.assertEqual(3.0, squared_distance(sv, narr))

    def test_hash(self):
        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
@@ -466,7 +457,7 @@ def test_infer_schema(self):
                raise ValueError("Expected a matrix but got type %r" % type(m))


-@unittest.skipIf(not _have_scipy, "SciPy not installed")
+@unittest.skipIf(not have_scipy, "SciPy not installed")
class SciPyTests(MLlibTestCase):

"""
@@ -476,6 +467,8 @@ class SciPyTests(MLlibTestCase):

    def test_serialize(self):
        from scipy.sparse import lil_matrix
+
+        ser = PickleSerializer()
        lil = lil_matrix((4, 1))
        lil[1, 0] = 1
        lil[3, 0] = 2
@@ -621,13 +614,10 @@ def test_regression(self):

if __name__ == "__main__":
    from pyspark.mllib.tests.test_linalg import *
-    if not _have_scipy:
-        print("NOTE: Skipping SciPy tests as it does not seem to be installed")
+
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
-    if not _have_scipy:
-        print("NOTE: SciPy tests were skipped as it does not seem to be installed")
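Worth noting why the two NOTE prints could simply be dropped: the class-level @unittest.skipIf(not have_scipy, "SciPy not installed") decorator already makes the runner (unittest at verbosity=2, or xmlrunner when available) report each skipped test together with its reason, so a manual reminder in __main__ added nothing.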
5 changes: 0 additions & 5 deletions python/pyspark/testing/mllibutils.py
@@ -18,14 +18,9 @@
import unittest

from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer
from pyspark.sql import SparkSession


-def make_serializer():
-    return PickleSerializer()
-
-
class MLlibTestCase(unittest.TestCase):
    def setUp(self):
        self.sc = SparkContext('local[4]', "MLlib tests")
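With make_serializer gone, the remainder of python/pyspark/testing/mllibutils.py is roughly the sketch below. The diff view is truncated after setUp, so the SparkSession line and the tearDown body are assumptions, not part of what is shown above:

import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession


class MLlibTestCase(unittest.TestCase):
    def setUp(self):
        self.sc = SparkContext('local[4]', "MLlib tests")
        self.spark = SparkSession(self.sc)  # assumed; the diff view cuts off here

    def tearDown(self):
        # assumed counterpart to setUp: stop the session (and with it the context)
        self.spark.stop()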