|
49 | 49 | from pyspark.mllib.feature import IDF |
50 | 50 | from pyspark.mllib.feature import StandardScaler |
51 | 51 | from pyspark.mllib.feature import ElementwiseProduct |
| 52 | +from pyspark.mllib.util import MLUtils |
52 | 53 | from pyspark.serializers import PickleSerializer |
53 | 54 | from pyspark.streaming import StreamingContext |
54 | 55 | from pyspark.sql import SQLContext |
@@ -1010,6 +1011,48 @@ def collect(rdd): |
1010 | 1011 | self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) |
1011 | 1012 |
|
1012 | 1013 |
|
class MLUtilsTests(MLlibTestCase):
    """Tests for ``MLUtils.appendBias`` and ``MLUtils.loadVectors``."""

    def test_append_bias(self):
        """appendBias on a plain list yields a DenseVector ending in 1.0."""
        data = [2.0, 2.0, 2.0]
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret[3], 1.0)
        # Exact type comparison is deliberate: the result must be a
        # DenseVector itself.
        self.assertEqual(type(ret), DenseVector)

    def test_append_bias_with_vector(self):
        """appendBias on a dense vector yields a DenseVector ending in 1.0."""
        data = Vectors.dense([2.0, 2.0, 2.0])
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret[3], 1.0)
        self.assertEqual(type(ret), DenseVector)

    def test_append_bias_with_sp_vector(self):
        """appendBias on a sparse vector keeps it sparse and adds the bias."""
        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
        # Returned value must be SparseVector
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret, expected)
        self.assertEqual(type(ret), SparseVector)

    def test_load_vectors(self):
        """Vectors saved as a text file are read back by loadVectors."""
        import shutil
        data = [
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 3.0]
        ]
        temp_dir = tempfile.mkdtemp()
        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
        # No ``except`` clause here: a bare ``except: self.fail()`` would
        # replace the real assertion message or traceback with an empty
        # failure. Letting the exception propagate makes the test runner
        # report the actual cause.
        try:
            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
            ret = ret_rdd.collect()
            self.assertEqual(len(ret), 2)
            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
        finally:
            # Remove the whole temporary directory, not just the output
            # path inside it: temp_dir always exists at this point (so the
            # cleanup cannot fail because the save never happened), and
            # nothing created by mkdtemp() is left behind.
            shutil.rmtree(temp_dir)
| 1054 | + |
| 1055 | + |
1013 | 1056 | if __name__ == "__main__": |
1014 | 1057 | if not _have_scipy: |
1015 | 1058 | print("NOTE: Skipping SciPy tests as it does not seem to be installed") |
|
0 commit comments