|
26 | 26 | from time import time, sleep |
27 | 27 | from shutil import rmtree |
28 | 28 |
|
29 | | -from numpy import array, array_equal, zeros, inf, all, random |
| 29 | +from numpy import ( |
| 30 | + array, array_equal, zeros, inf, random, exp, dot, all, mean) |
30 | 31 | from numpy import sum as array_sum |
31 | 32 | from py4j.protocol import Py4JJavaError |
32 | 33 |
|
|
45 | 46 | from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ |
46 | 47 | DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT |
47 | 48 | from pyspark.mllib.regression import LabeledPoint |
| 49 | +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD |
48 | 50 | from pyspark.mllib.random import RandomRDDs |
49 | 51 | from pyspark.mllib.stat import Statistics |
50 | 52 | from pyspark.mllib.feature import Word2Vec |
@@ -1037,6 +1039,137 @@ def test_dim(self): |
1037 | 1039 | self.assertEqual(len(point.features), 2) |
1038 | 1040 |
|
1039 | 1041 |
|
class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
    """Integration tests for StreamingLogisticRegressionWithSGD.

    Each test feeds synthetic logistic-regression data through a
    queue-backed DStream and polls (via ``_ssc_wait``) until the
    streaming model has had time to train and/or predict.
    """

    @staticmethod
    def generateLogisticInput(offset, scale, nPoints, seed):
        """
        Generate ``nPoints`` labeled samples from the logistic model

            p(y = 1 | x) = 1 / (1 + exp(-(x * scale + offset)))

        where x is drawn from a standard normal distribution and the
        binary label for each sample is obtained by thresholding the
        sigmoid probability against a uniform random draw.

        :param offset: intercept term of the logistic model.
        :param scale: weight applied to each sampled x.
        :param nPoints: number of LabeledPoints to generate.
        :param seed: RNG seed, so each batch is reproducible.
        :return: list of LabeledPoint with one-dimensional features.
        """
        rng = random.RandomState(seed)
        x = rng.randn(nPoints)
        sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
        y_p = rng.rand(nPoints)
        # Label is 1.0 where the uniform draw falls at or below the
        # sigmoid probability, 0.0 otherwise.
        cut_off = y_p <= sigmoid
        y_p[cut_off] = 1.0
        y_p[~cut_off] = 0.0
        return [
            LabeledPoint(y_p[i], Vectors.dense([x[i]]))
            for i in range(nPoints)]

    def test_parameter_accuracy(self):
        """
        Test that the final value of weights is close to the desired value.
        """
        # 20 batches of 100 points each, all drawn from the same model
        # (scale=1.5, offset=0) but with distinct seeds.
        input_batches = [
            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
            for i in range(20)]
        input_stream = self.ssc.queueStream(input_batches)

        slr = StreamingLogisticRegressionWithSGD(
            stepSize=0.2, numIterations=25)
        slr.setInitialWeights([0.0])
        slr.trainOn(input_stream)

        t = time()
        self.ssc.start()
        self._ssc_wait(t, 20.0, 0.01)
        # Relative error of the learned weight against the true weight
        # (1.5); expected to settle near 0.1, checked to 1 decimal place.
        rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
        self.assertAlmostEqual(rel, 0.1, 1)

    def test_convergence(self):
        """
        Test that weights converge to the required value on toy data.
        """
        input_batches = [
            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
            for i in range(20)]
        input_stream = self.ssc.queueStream(input_batches)
        # Snapshot of the model's first weight after each batch, appended
        # from the foreachRDD callback below.
        models = []

        slr = StreamingLogisticRegressionWithSGD(
            stepSize=0.2, numIterations=25)
        slr.setInitialWeights([0.0])
        slr.trainOn(input_stream)
        input_stream.foreachRDD(
            lambda x: models.append(slr.latestModel().weights[0]))

        t = time()
        self.ssc.start()
        self._ssc_wait(t, 15.0, 0.01)
        t_models = array(models)
        # Per-batch change in the learned weight.
        diff = t_models[1:] - t_models[:-1]

        # Test that weights improve with a small tolerance,
        # i.e. never regress by more than 0.1 and strictly improve on
        # more than one batch.
        self.assertTrue(all(diff >= -0.1))
        self.assertTrue(array_sum(diff > 0) > 1)

    @staticmethod
    def calculate_accuracy_error(true, predicted):
        """Return the mean absolute error between true and predicted labels."""
        return sum(abs(array(true) - array(predicted))) / len(true)

    def test_predictions(self):
        """Test predicted values on a toy model."""
        input_batches = []
        for i in range(20):
            batch = self.sc.parallelize(
                self.generateLogisticInput(0, 1.5, 100, 42 + i))
            # predictOnValues expects (label, features) key-value pairs.
            input_batches.append(batch.map(lambda x: (x.label, x.features)))
        input_stream = self.ssc.queueStream(input_batches)

        slr = StreamingLogisticRegressionWithSGD(
            stepSize=0.2, numIterations=25)
        # Initial weight is set to the true weight (1.5) so prediction
        # quality is tested independently of training.
        slr.setInitialWeights([1.5])
        predict_stream = slr.predictOnValues(input_stream)
        true_predicted = []
        predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
        t = time()
        self.ssc.start()
        self._ssc_wait(t, 5.0, 0.01)

        # Test that the accuracy error is no more than 0.4 on each batch.
        for batch in true_predicted:
            true, predicted = zip(*batch)
            self.assertTrue(
                self.calculate_accuracy_error(true, predicted) < 0.4)

    def test_training_and_prediction(self):
        """Test that the model improves on toy data with no. of batches"""
        input_batches = [
            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
            for i in range(20)]
        # Parallel stream of (label, features) pairs built from the same
        # batches, used for prediction while the model trains.
        predict_batches = [
            b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]

        slr = StreamingLogisticRegressionWithSGD(
            stepSize=0.01, numIterations=25)
        # Deliberately poor initial weight so there is room to improve.
        slr.setInitialWeights([-0.1])
        errors = []

        def collect_errors(rdd):
            # Records the mean absolute error of each predicted batch.
            true, predicted = zip(*rdd.collect())
            errors.append(self.calculate_accuracy_error(true, predicted))

        true_predicted = []
        input_stream = self.ssc.queueStream(input_batches)
        predict_stream = self.ssc.queueStream(predict_batches)
        slr.trainOn(input_stream)
        ps = slr.predictOnValues(predict_stream)
        ps.foreachRDD(lambda x: collect_errors(x))

        t = time()
        self.ssc.start()
        self._ssc_wait(t, 20.0, 0.01)

        # Test that the improvement in error is at least 0.3 between an
        # early batch and the final batch.
        # NOTE(review): this compares errors[1] (second batch), not
        # errors[0] — presumably to skip the pre-training first batch;
        # confirm that is intentional.
        self.assertTrue(errors[1] - errors[-1] > 0.3)
| 1172 | + |
1040 | 1173 | if __name__ == "__main__": |
1041 | 1174 | if not _have_scipy: |
1042 | 1175 | print("NOTE: Skipping SciPy tests as it does not seem to be installed") |
|
0 commit comments