Skip to content

Commit b17cc7b

Browse files
committed
fix nits
1 parent 6d30d77 commit b17cc7b

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

python/pyspark/ml/classification.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,9 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
932932
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
933933
>>> model.transform(test1).head().indexed
934934
0.0
935+
>>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
936+
>>> model.transform(test2).head().indexed
937+
2.0
935938
936939
.. versionadded:: 2.0.0
937940
"""
@@ -976,11 +979,13 @@ def getClassifier(self):
976979
return self.getOrDefault(self.classifier)
977980

978981
def _fit(self, dataset):
979-
980982
labelCol = self.getLabelCol()
981-
featureCol = self.getFeaturesCol()
982-
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"])
983-
multiclassLabeled = dataset.select(labelCol, featureCol)
983+
featuresCol = self.getFeaturesCol()
984+
predictionCol = self.getPredictionCol()
985+
986+
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
987+
988+
multiclassLabeled = dataset.select(labelCol, featuresCol)
984989

985990
# persist if underlying dataset is not persistent.
986991
handlePersistence = \
@@ -991,24 +996,23 @@ def _fit(self, dataset):
991996
models = []
992997

993998
for index in range(0, numClasses):
994-
# newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
995-
labelColName = "mc2b$" + str(index)
999+
binaryLabelCol = "mc2b$" + str(index)
9961000
trainingDataset = multiclassLabeled.withColumn(
997-
labelColName,
998-
when(dataset[self.getLabelCol()] == float(index), 1.0).otherwise(0.0))
1001+
binaryLabelCol,
1002+
when(dataset[labelCol] == float(index), 1.0).otherwise(0.0))
9991003
classifier = self.getClassifier()
1000-
paramMap = dict([(classifier.labelCol, labelColName),
1001-
(classifier.featuresCol, self.getFeaturesCol()),
1002-
(classifier.predictionCol, self.getPredictionCol())])
1004+
paramMap = dict([(classifier.labelCol, binaryLabelCol),
1005+
(classifier.featuresCol, featuresCol),
1006+
(classifier.predictionCol, predictionCol)])
10031007
models.append(classifier.fit(trainingDataset, paramMap))
10041008

10051009
if handlePersistence:
10061010
multiclassLabeled.unpersist()
10071011

10081012
return OneVsRestModel(models=models)\
1009-
.setFeaturesCol(self.getFeaturesCol())\
1010-
.setLabelCol(self.getLabelCol())\
1011-
.setPredictionCol(self.getPredictionCol())
1013+
.setFeaturesCol(featuresCol)\
1014+
.setLabelCol(labelCol)\
1015+
.setPredictionCol(predictionCol)
10121016

10131017
@since("2.0.0")
10141018
def copy(self, extra=None):
@@ -1022,10 +1026,10 @@ def copy(self, extra=None):
10221026
"""
10231027
if extra is None:
10241028
extra = dict()
1025-
newOVR = Params.copy(self, extra)
1029+
newOvr = Params.copy(self, extra)
10261030
if self.isSet(self.classifier):
1027-
newOVR.setClassifier(self.getClassifier().copy(extra))
1028-
return newOVR
1031+
newOvr.setClassifier(self.getClassifier().copy(extra))
1032+
return newOvr
10291033

10301034

10311035
class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):

0 commit comments

Comments
 (0)