1515# limitations under the License.
1616#
1717
18- import warnings
19-
2018import operator
21- import uuid
19+ import warnings
2220
23- from pyspark import since
2421from pyspark .ml import Estimator , Model
25- from pyspark .ml .util import *
26- from pyspark .ml .wrapper import JavaEstimator , JavaModel
27- from pyspark .ml .param import TypeConverters
2822from pyspark .ml .param .shared import *
2923from pyspark .ml .regression import (
3024 RandomForestParams , TreeEnsembleParams , DecisionTreeModel , TreeEnsembleModels )
25+ from pyspark .ml .util import *
26+ from pyspark .ml .wrapper import JavaEstimator , JavaModel
3127from pyspark .mllib .common import inherit_doc
3228from pyspark .sql .functions import udf , when
33- from pyspark .sql .types import ArrayType , MapType , IntegerType , DoubleType
29+ from pyspark .sql .types import ArrayType , DoubleType
3430from pyspark .storagelevel import StorageLevel
3531
3632__all__ = ['LogisticRegression' , 'LogisticRegressionModel' ,
3733 'DecisionTreeClassifier' , 'DecisionTreeClassificationModel' ,
3834 'GBTClassifier' , 'GBTClassificationModel' ,
3935 'RandomForestClassifier' , 'RandomForestClassificationModel' ,
4036 'NaiveBayes' , 'NaiveBayesModel' ,
41- 'MultilayerPerceptronClassifier' , 'MultilayerPerceptronClassificationModel' ]
37+ 'MultilayerPerceptronClassifier' , 'MultilayerPerceptronClassificationModel' ,
38+ 'OneVsRest' , 'OneVsRestModel' ]
4239
4340
4441@inherit_doc
@@ -923,16 +920,17 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
923920 >>> from pyspark.sql import Row
924921 >>> from pyspark.mllib.linalg import Vectors
925922 >>> df = sc.parallelize([
926- ... Row(label=1.0, features=Vectors.dense(1.0)),
927- ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
923+ ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)),
924+ ... Row(label=1.0, features=Vectors.sparse(2, [], [])),
925+ ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF()
928926 >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
929927 >>> ovr = OneVsRest(classifier=lr).setPredictionCol("indexed")
930928 >>> model = ovr.fit(df)
931- >>> model.models[0].weights
932- >>> model.models[0].coefficients
933- >>> model.models[0].intercept
934- >>> test0 = sc.parallelize([Row(features=Vectors.dense(- 1.0))]).toDF()
935- >>> model.transform(test0).show()
929+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF()
930+ >>> model.transform(test0).head().indexed
931+ 1.0
932+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [ 1.0] ))]).toDF()
933+ >>> model.transform(test1).head().indexed
936934 0.0
937935
938936 .. versionadded:: 2.0.0
@@ -965,7 +963,7 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif
@since("2.0.0")
def setClassifier(self, value):
    """
    Sets the value of :py:attr:`classifier`.

    :param value: Object stored in the param map under the ``classifier`` param
    :return: This instance, so that calls can be chained
    """
    classifier_param = self.classifier
    self._paramMap[classifier_param] = value
    return self
@@ -985,7 +983,7 @@ def _fit(self, dataset):
985983 multiclassLabeled = dataset .select (labelCol , featureCol )
986984
987985 # persist if underlying dataset is not persistent.
988- handlePersistence = \
986+ handlePersistence = \
989987 dataset .rdd .getStorageLevel () == StorageLevel (False , False , False , False )
990988 if handlePersistence :
991989 multiclassLabeled .persist (StorageLevel .MEMORY_AND_DISK )
@@ -1007,32 +1005,32 @@ def _fit(self, dataset):
10071005 if handlePersistence :
10081006 multiclassLabeled .unpersist ()
10091007
1010- return OneVsRestModel (models = models )
1011-
1012- # @since("2.0.0")
1013- # def copy (self, extra=None):
1014- # """
1015- # Creates a copy of this instance with a randomly generated uid
1016- # and some extra params. This copies creates a deep copy of
1017- # the embedded paramMap, and copies the embedded and extra parameters over.
1018-
1019- # :param extra: Extra parameters to copy to the new instance
1020- # :return: Copy of this instance
1021- # """
1022- # if extra is None:
1023- # extra = dict()
1024- # newCV = Params.copy(self, extra)
1025- # if self.isSet(self.estimator) :
1026- # newCV.setEstimator(self.getEstimator().copy(extra) )
1027- # # estimatorParamMaps remain the same
1028- # if self.isSet(self.evaluator ):
1029- # newCV.setEvaluator (self.getEvaluator ().copy(extra))
1030- # return newCV
1008+ return OneVsRestModel (models = models )\
1009+ . setFeaturesCol ( self . getFeaturesCol ())\
1010+ . setLabelCol ( self . getLabelCol ())\
1011+ . setPredictionCol (self . getPredictionCol ())
1012+
@since("2.0.0")
def copy(self, extra=None):
    """
    Creates a copy of this instance via ``Params.copy`` and, when a
    classifier has been set, gives the copy its own copy of that classifier.

    :param extra: Extra parameters to copy to the new instance
                  (an empty dict is used when omitted)
    :return: Copy of this instance
    """
    extra = {} if extra is None else extra
    duplicate = Params.copy(self, extra)
    if not self.isSet(self.classifier):
        # No classifier configured: there is nothing further to copy.
        return duplicate
    # Copy the embedded classifier too, forwarding the same extra params,
    # and install that copy on the duplicate.
    classifier_copy = self.getClassifier().copy(extra)
    duplicate.setClassifier(classifier_copy)
    return duplicate
10311029
10321030
10331031class OneVsRestModel (Model , HasFeaturesCol , HasLabelCol , HasPredictionCol ):
10341032 """
1035- Model produced by [[ OneVsRest]] .
1033+ Model fitted by OneVsRest.
10361034 This stores the models resulting from training k binary classifiers: one for each class.
10371035 Each example is scored against all k models, and the model with the highest score
10381036 is picked to label the example.
@@ -1055,7 +1053,7 @@ def _transform(self, dataset):
10551053 newDataset = dataset .withColumn (accColName , initUDF (dataset [origCols [0 ]]))
10561054
10571055 # persist if underlying dataset is not persistent.
1058- handlePersistence = \
1056+ handlePersistence = \
10591057 dataset .rdd .getStorageLevel () == StorageLevel (False , False , False , False )
10601058 if handlePersistence :
10611059 newDataset .persist (StorageLevel .MEMORY_AND_DISK )
@@ -1086,26 +1084,26 @@ def _transform(self, dataset):
10861084
10871085 # output the index of the classifier with highest confidence as prediction
10881086 labelUDF = udf (
1089- lambda predictions : float (max (enumerate (predictions ), key = operator .itemgetter (1 ))[0 ]))
1087+ lambda predictions : float (max (enumerate (predictions ), key = operator .itemgetter (1 ))[0 ]),
1088+ DoubleType ())
10901089
10911090 # output label and label metadata as prediction
10921091 return aggregatedDataset .withColumn (
10931092 self .getPredictionCol (), labelUDF (aggregatedDataset [accColName ])).drop (accColName )
10941093
1095- # @since("1.4.0")
1096- # def copy(self, extra=None):
1097- # """
1098- # Creates a copy of this instance with a randomly generated uid
1099- # and some extra params. This copies the underlying bestModel,
1100- # creates a deep copy of the embedded paramMap, and
1101- # copies the embedded and extra parameters over.
1102-
1103- # :param extra: Extra parameters to copy to the new instance
1104- # :return: Copy of this instance
1105- # """
1106- # if extra is None:
1107- # extra = dict()
1108- # return OneVsRestModel(self.models.copy(extra))
@since("2.0.0")
def copy(self, extra=None):
    """
    Creates a new OneVsRestModel built from copies of this model's
    underlying models.

    :param extra: Extra parameters forwarded to each underlying
                  model's ``copy`` (an empty dict is used when omitted)
    :return: New OneVsRestModel wrapping the copied models
    """
    if extra is None:
        extra = dict()
    # Iterate self.models directly. The previous code iterated
    # ``self.models.copy(extra)``, but a list's copy() accepts no
    # argument (and Python 2 lists have no copy() at all), so the
    # call failed before any model was copied.
    # NOTE(review): only the models are passed to the new instance; column
    # params set on this model are not forwarded here -- confirm intended.
    return OneVsRestModel([model.copy(extra) for model in self.models])
11091107
11101108
11111109if __name__ == "__main__" :
0 commit comments