|
42 | 42 | import numpy as np |
43 | 43 |
|
44 | 44 | from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer |
45 | | -from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier |
| 45 | +from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest |
46 | 46 | from pyspark.ml.clustering import KMeans |
47 | 47 | from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator |
48 | 48 | from pyspark.ml.feature import * |
@@ -831,6 +831,36 @@ def test_logistic_regression_summary(self): |
831 | 831 | self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) |
832 | 832 |
|
833 | 833 |
|
| 834 | +class OneVsRestTests(PySparkTestCase): |
| 835 | + |
| 836 | + def test_copy(self): |
| 837 | + sqlContext = SQLContext(self.sc) |
| 838 | + df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), |
| 839 | + (1.0, Vectors.sparse(2, [], [])), |
| 840 | + (2.0, Vectors.dense(0.5, 0.5))], |
| 841 | + ["label", "features"]) |
| 842 | + lr = LogisticRegression(maxIter=5, regParam=0.01) |
| 843 | + ovr = OneVsRest(classifier=lr) |
| 844 | + ovr1 = ovr.copy({lr.maxIter: 10}) |
| 845 | + self.assertEqual(ovr.getClassifier().getMaxIter(), 5) |
| 846 | + self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) |
| 847 | + model = ovr.fit(df) |
| 848 | + model1 = model.copy({model.predictionCol: "indexed"}) |
| 849 | + self.assertEqual(model1.getPredictionCol(), "indexed") |
| 850 | + |
| 851 | + def test_output_columns(self): |
| 852 | + sqlContext = SQLContext(self.sc) |
| 853 | + df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), |
| 854 | + (1.0, Vectors.sparse(2, [], [])), |
| 855 | + (2.0, Vectors.dense(0.5, 0.5))], |
| 856 | + ["label", "features"]) |
| 857 | + lr = LogisticRegression(maxIter=5, regParam=0.01) |
| 858 | + ovr = OneVsRest(classifier=lr) |
| 859 | + model = ovr.fit(df) |
| 860 | + output = model.transform(df) |
| 861 | + self.assertEqual(output.columns, ["label", "features", "prediction"]) |
| 862 | + |
| 863 | + |
834 | 864 | if __name__ == "__main__": |
835 | 865 | from pyspark.ml.tests import * |
836 | 866 | if xmlrunner: |
|
0 commit comments