@@ -932,6 +932,9 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
932932 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
933933 >>> model.transform(test1).head().indexed
934934 0.0
935+ >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
936+ >>> model.transform(test2).head().indexed
937+ 2.0
935938
936939 .. versionadded:: 2.0.0
937940 """
@@ -976,11 +979,13 @@ def getClassifier(self):
976979 return self .getOrDefault (self .classifier )
977980
978981 def _fit (self , dataset ):
979-
980982 labelCol = self .getLabelCol ()
981- featureCol = self .getFeaturesCol ()
982- numClasses = int (dataset .agg ({labelCol : "max" }).head ()["max(" + labelCol + ")" ])
983- multiclassLabeled = dataset .select (labelCol , featureCol )
983+ featuresCol = self .getFeaturesCol ()
984+ predictionCol = self .getPredictionCol ()
985+
986+ numClasses = int (dataset .agg ({labelCol : "max" }).head ()["max(" + labelCol + ")" ]) + 1
987+
988+ multiclassLabeled = dataset .select (labelCol , featuresCol )
984989
985990 # persist if underlying dataset is not persistent.
986991 handlePersistence = \
@@ -991,24 +996,23 @@ def _fit(self, dataset):
991996 models = []
992997
993998 for index in range (0 , numClasses ):
994- # newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
995- labelColName = "mc2b$" + str (index )
999+ binaryLabelCol = "mc2b$" + str (index )
9961000 trainingDataset = multiclassLabeled .withColumn (
997- labelColName ,
998- when (dataset [self . getLabelCol () ] == float (index ), 1.0 ).otherwise (0.0 ))
1001+ binaryLabelCol ,
1002+ when (dataset [labelCol ] == float (index ), 1.0 ).otherwise (0.0 ))
9991003 classifier = self .getClassifier ()
1000- paramMap = dict ([(classifier .labelCol , labelColName ),
1001- (classifier .featuresCol , self . getFeaturesCol () ),
1002- (classifier .predictionCol , self . getPredictionCol () )])
1004+ paramMap = dict ([(classifier .labelCol , binaryLabelCol ),
1005+ (classifier .featuresCol , featuresCol ),
1006+ (classifier .predictionCol , predictionCol )])
10031007 models .append (classifier .fit (trainingDataset , paramMap ))
10041008
10051009 if handlePersistence :
10061010 multiclassLabeled .unpersist ()
10071011
10081012 return OneVsRestModel (models = models )\
1009- .setFeaturesCol (self . getFeaturesCol () )\
1010- .setLabelCol (self . getLabelCol () )\
1011- .setPredictionCol (self . getPredictionCol () )
1013+ .setFeaturesCol (featuresCol )\
1014+ .setLabelCol (labelCol )\
1015+ .setPredictionCol (predictionCol )
10121016
10131017 @since ("2.0.0" )
10141018 def copy (self , extra = None ):
@@ -1022,10 +1026,10 @@ def copy(self, extra=None):
10221026 """
10231027 if extra is None :
10241028 extra = dict ()
1025- newOVR = Params .copy (self , extra )
1029+ newOvr = Params .copy (self , extra )
10261030 if self .isSet (self .classifier ):
1027- newOVR .setClassifier (self .getClassifier ().copy (extra ))
1028- return newOVR
1031+ newOvr .setClassifier (self .getClassifier ().copy (extra ))
1032+ return newOvr
10291033
10301034
10311035class OneVsRestModel (Model , HasFeaturesCol , HasLabelCol , HasPredictionCol ):
0 commit comments