Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand All @@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {

/**
* Param for the base binary classifier to which multiclass classification is reduced.
Expand Down Expand Up @@ -299,6 +301,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
 * Sets the value of param [[weightCol]], the name of the column that holds instance weights.
 *
 * The weight column is only used when [[classifier]] supports instance weights;
 * otherwise this setting is ignored.
 * If the param is not set or is an empty string, all instance weights are treated as 1.0.
 * The param is not set by default, so every instance has weight one.
 *
 * @param value name of the weight column
 * @return this estimator, so that setter calls can be chained
 * @group setParam
 */
@Since("2.3.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
Expand All @@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

val multiclassLabeled = if (weightColIsUsed) {
dataset.select($(labelCol), $(featuresCol), $(weightCol))
} else {
dataset.select($(labelCol), $(featuresCol))
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand All @@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
if (weightColIsUsed) {
val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
paramMap.put(classifier_.weightCol -> getWeightCol)
classifier_.fit(trainingDataset, paramMap)
} else {
classifier.fit(trainingDataset, paramMap)
}
}.toArray[ClassificationModel[_, _]]

if (handlePersistence) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}

test("SPARK-21306: OneVsRest should support setWeightCol") {
  // Add a constant weight column so both estimators below are fitted on the same input.
  val weightedData = dataset.withColumn("weight", lit(1.0))

  // Case 1: the base classifier mixes in HasWeightCol.
  val weightAwareOvr = new OneVsRest()
    .setClassifier(new LogisticRegression())
    .setWeightCol("weight")
  assert(weightAwareOvr.fit(weightedData) !== null)

  // Case 2: the base classifier does not mix in HasWeightCol; fitting must still succeed.
  val weightUnawareOvr = new OneVsRest()
    .setClassifier(new DecisionTreeClassifier())
    .setWeightCol("weight")
  assert(weightUnawareOvr.fit(weightedData) !== null)
}

test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
Expand Down
27 changes: 21 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def weights(self):
return self._call_java("weights")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
Expand Down Expand Up @@ -1394,20 +1394,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
classifier=None, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
classifier=None)
classifier=None, weightCol=None)
"""
super(OneVsRest, self).__init__()
kwargs = self._input_kwargs
self._set(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
classifier=None, weightCol=None):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
classifier=None, weightCol=None):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
Expand All @@ -1423,7 +1425,18 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
weightCol = None
if (self.isDefined(self.weightCol) and self.getWeightCol()):
if isinstance(classifier, HasWeightCol):
weightCol = self.getWeightCol()
else:
warnings.warn("weightCol is ignored, "
"as it is not supported by {0} now.".format(classifier))

if weightCol:
multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
Expand All @@ -1439,6 +1452,8 @@ def trainSingleClass(index):
paramMap = dict([(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol)])
if weightCol:
paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)

# TODO: Parallel training for all classes.
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,20 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_support_for_weightCol(self):
    """OneVsRest.fit returns a model when weightCol is set, for both kinds of classifier."""
    rows = [
        (0.0, Vectors.dense(1.0, 0.8), 1.0),
        (1.0, Vectors.sparse(2, [], []), 1.0),
        (2.0, Vectors.dense(0.5, 0.5), 1.0),
    ]
    weighted_df = self.spark.createDataFrame(rows, ["label", "features", "weight"])

    # Case 1: the base classifier inherits HasWeightCol.
    weight_aware = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=0.01),
                             weightCol="weight")
    self.assertIsNotNone(weight_aware.fit(weighted_df))

    # Case 2: the base classifier does not inherit HasWeightCol; fit must still succeed.
    weight_unaware = OneVsRest(classifier=DecisionTreeClassifier(), weightCol="weight")
    self.assertIsNotNone(weight_unaware.fit(weighted_df))


class HashingTFTest(SparkSessionTestCase):

Expand Down