Commit 6d30d77

add copy and more tests
1 parent 417d13f commit 6d30d77

File tree: 1 file changed (+54, -56 lines)


python/pyspark/ml/classification.py

Lines changed: 54 additions & 56 deletions
@@ -15,30 +15,27 @@
 # limitations under the License.
 #

-import warnings
-
 import operator
-import uuid
+import warnings

-from pyspark import since
 from pyspark.ml import Estimator, Model
-from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
-from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
     RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql.functions import udf, when
-from pyspark.sql.types import ArrayType, MapType, IntegerType, DoubleType
+from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel

 __all__ = ['LogisticRegression', 'LogisticRegressionModel',
            'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
            'GBTClassifier', 'GBTClassificationModel',
            'RandomForestClassifier', 'RandomForestClassificationModel',
            'NaiveBayes', 'NaiveBayesModel',
-           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel']
+           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
+           'OneVsRest', 'OneVsRestModel']


 @inherit_doc
@@ -923,16 +920,17 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
     >>> from pyspark.sql import Row
     >>> from pyspark.mllib.linalg import Vectors
     >>> df = sc.parallelize([
-    ...     Row(label=1.0, features=Vectors.dense(1.0)),
-    ...     Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
+    ...     Row(label=0.0, features=Vectors.dense(1.0, 0.8)),
+    ...     Row(label=1.0, features=Vectors.sparse(2, [], [])),
+    ...     Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF()
     >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
     >>> ovr = OneVsRest(classifier=lr).setPredictionCol("indexed")
     >>> model = ovr.fit(df)
-    >>> model.models[0].weights
-    >>> model.models[0].coefficients
-    >>> model.models[0].intercept
-    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
-    >>> model.transform(test0).show()
+    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF()
+    >>> model.transform(test0).head().indexed
+    1.0
+    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
+    >>> model.transform(test1).head().indexed
     0.0

     .. versionadded:: 2.0.0
@@ -965,7 +963,7 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif
     @since("2.0.0")
     def setClassifier(self, value):
         """
-        Sets the value of :py:attr:`estimator`.
+        Sets the value of :py:attr:`classifier`.
         """
         self._paramMap[self.classifier] = value
         return self
@@ -985,7 +983,7 @@ def _fit(self, dataset):
         multiclassLabeled = dataset.select(labelCol, featureCol)

         # persist if underlying dataset is not persistent.
-        handlePersistence =\
+        handlePersistence = \
             dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
@@ -1007,32 +1005,32 @@ def _fit(self, dataset):
         if handlePersistence:
             multiclassLabeled.unpersist()

-        return OneVsRestModel(models=models)
-
-    # @since("2.0.0")
-    # def copy(self, extra=None):
-    #     """
-    #     Creates a copy of this instance with a randomly generated uid
-    #     and some extra params. This copies creates a deep copy of
-    #     the embedded paramMap, and copies the embedded and extra parameters over.
-
-    #     :param extra: Extra parameters to copy to the new instance
-    #     :return: Copy of this instance
-    #     """
-    #     if extra is None:
-    #         extra = dict()
-    #     newCV = Params.copy(self, extra)
-    #     if self.isSet(self.estimator):
-    #         newCV.setEstimator(self.getEstimator().copy(extra))
-    #     # estimatorParamMaps remain the same
-    #     if self.isSet(self.evaluator):
-    #         newCV.setEvaluator(self.getEvaluator().copy(extra))
-    #     return newCV
+        return OneVsRestModel(models=models)\
+            .setFeaturesCol(self.getFeaturesCol())\
+            .setLabelCol(self.getLabelCol())\
+            .setPredictionCol(self.getPredictionCol())
+
+    @since("2.0.0")
+    def copy(self, extra=None):
+        """
+        Creates a copy of this instance with a randomly generated uid
+        and some extra params. This copies creates a deep copy of
+        the embedded paramMap, and copies the embedded and extra parameters over.
+
+        :param extra: Extra parameters to copy to the new instance
+        :return: Copy of this instance
+        """
+        if extra is None:
+            extra = dict()
+        newOVR = Params.copy(self, extra)
+        if self.isSet(self.classifier):
+            newOVR.setClassifier(self.getClassifier().copy(extra))
+        return newOVR


 class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
     """
-    Model produced by [[OneVsRest]].
+    Model fitted by OneVsRest.
     This stores the models resulting from training k binary classifiers: one for each class.
     Each example is scored against all k models, and the model with the highest score
     is picked to label the example.
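
For orientation, here is a minimal usage sketch of the newly added OneVsRest.copy, based on the implementation in the hunk above. It is illustrative only, not part of the commit or its doctests; it assumes the doctest environment shown earlier (imports and a SparkContext named sc), and the name ovr_copy is made up.

    # Hypothetical sketch of the expected copy() behaviour; not code from this commit.
    lr = LogisticRegression(maxIter=5, regParam=0.01)
    ovr = OneVsRest(classifier=lr)
    ovr_copy = ovr.copy()  # Params.copy of the paramMap plus a copy of the embedded classifier
    # The copy is expected to carry its own classifier instance with the same parameter values.
    assert ovr_copy.getClassifier() is not ovr.getClassifier()
    assert ovr_copy.getClassifier().getMaxIter() == 5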
@@ -1055,7 +1053,7 @@ def _transform(self, dataset):
         newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

         # persist if underlying dataset is not persistent.
-        handlePersistence =\
+        handlePersistence = \
             dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
         if handlePersistence:
             newDataset.persist(StorageLevel.MEMORY_AND_DISK)
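
A side note on the handlePersistence lines touched here and in _fit above: StorageLevel(False, False, False, False) describes an RDD that is cached nowhere (no disk, no memory, no off-heap, not deserialized), so the comparison amounts to "persist the working DataFrame only if the caller has not already persisted the input." A small sketch of that check; the helper name needs_caching is an assumption of this note, not code from the commit.

    from pyspark.storagelevel import StorageLevel

    def needs_caching(df):
        # True only when the DataFrame's underlying RDD has never been persisted/cached.
        return df.rdd.getStorageLevel() == StorageLevel(False, False, False, False)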
@@ -1086,26 +1084,26 @@ def _transform(self, dataset):

         # output the index of the classifier with highest confidence as prediction
         labelUDF = udf(
-            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]))
+            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
+            DoubleType())

         # output label and label metadata as prediction
         return aggregatedDataset.withColumn(
             self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)

-    # @since("1.4.0")
-    # def copy(self, extra=None):
-    #     """
-    #     Creates a copy of this instance with a randomly generated uid
-    #     and some extra params. This copies the underlying bestModel,
-    #     creates a deep copy of the embedded paramMap, and
-    #     copies the embedded and extra parameters over.
-
-    #     :param extra: Extra parameters to copy to the new instance
-    #     :return: Copy of this instance
-    #     """
-    #     if extra is None:
-    #         extra = dict()
-    #     return OneVsRestModel(self.models.copy(extra))
+    @since("2.0.0")
+    def copy(self, extra=None):
+        """
+        Creates a copy of this instance with a randomly generated uid
+        and some extra params. This copies creates a deep copy of
+        the embedded paramMap, and copies the embedded and extra parameters over.
+
+        :param extra: Extra parameters to copy to the new instance
+        :return: Copy of this instance
+        """
+        if extra is None:
+            extra = dict()
+        return OneVsRestModel([model.copy(extra) for model in self.models.copy(extra)])


 if __name__ == "__main__":
