Skip to content

Commit e0cf36f

Browse files
committed
add unit tests
1 parent fb337cf commit e0cf36f

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

python/pyspark/ml/classification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1352,7 +1352,9 @@ def copy(self, extra=None):
13521352
"""
13531353
if extra is None:
13541354
extra = dict()
1355-
return self._copyValues(OneVsRestModel([model.copy(extra) for model in self.models]))
1355+
newModel = Params.copy(self, extra)
1356+
newModel.models = [model.copy(extra) for model in self.models]
1357+
return newModel
13561358

13571359

13581360
if __name__ == "__main__":

python/pyspark/ml/tests.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import numpy as np
4343

4444
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
45-
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
45+
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest
4646
from pyspark.ml.clustering import KMeans
4747
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
4848
from pyspark.ml.feature import *
@@ -831,6 +831,36 @@ def test_logistic_regression_summary(self):
831831
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
832832

833833

834+
class OneVsRestTests(PySparkTestCase):
835+
836+
def test_copy(self):
837+
sqlContext = SQLContext(self.sc)
838+
df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
839+
(1.0, Vectors.sparse(2, [], [])),
840+
(2.0, Vectors.dense(0.5, 0.5))],
841+
["label", "features"])
842+
lr = LogisticRegression(maxIter=5, regParam=0.01)
843+
ovr = OneVsRest(classifier=lr)
844+
ovr1 = ovr.copy({lr.maxIter: 10})
845+
self.assertEqual(ovr.getClassifier().getMaxIter(), 5)
846+
self.assertEqual(ovr1.getClassifier().getMaxIter(), 10)
847+
model = ovr.fit(df)
848+
model1 = model.copy({model.predictionCol: "indexed"})
849+
self.assertEqual(model1.getPredictionCol(), "indexed")
850+
851+
def test_output_columns(self):
852+
sqlContext = SQLContext(self.sc)
853+
df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
854+
(1.0, Vectors.sparse(2, [], [])),
855+
(2.0, Vectors.dense(0.5, 0.5))],
856+
["label", "features"])
857+
lr = LogisticRegression(maxIter=5, regParam=0.01)
858+
ovr = OneVsRest(classifier=lr)
859+
model = ovr.fit(df)
860+
output = model.transform(df)
861+
self.assertEqual(output.columns, ["label", "features", "prediction"])
862+
863+
834864
if __name__ == "__main__":
835865
from pyspark.ml.tests import *
836866
if xmlrunner:

0 commit comments

Comments
 (0)