
Commit 9f670ce

facaiy authored and yanboliang committed
[SPARK-21306][ML] For branch 2.0, OneVsRest should support setWeightCol
This PR is related to #18554 and is modified for branch 2.0.

## What changes were proposed in this pull request?

Add a `setWeightCol` method for OneVsRest. `weightCol` is ignored if the classifier doesn't inherit the HasWeightCol trait.

## How was this patch tested?

+ [x] Add a unit test.

Author: Yan Facai (颜发才) <[email protected]>

Closes #18764 from facaiy/BUG/branch-2.0_OneVsRest_support_setWeightCol.
1 parent c27a01a commit 9f670ce
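For context, here is a minimal Scala usage sketch of the new setter, mirroring the test added in this commit. It assumes an existing SparkSession named `spark`; the sample rows and the `weight` column name are illustrative and not part of the patch.

```scala
import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression, OneVsRest}
import org.apache.spark.ml.linalg.Vectors

// Illustrative weighted training data: (label, features, weight).
val df = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(1.0, 0.8), 1.0),
  (1.0, Vectors.sparse(2, Array.empty[Int], Array.empty[Double]), 1.0),
  (2.0, Vectors.dense(0.5, 0.5), 1.0)
)).toDF("label", "features", "weight")

// LogisticRegression supports instance weights, so OneVsRest forwards "weight"
// to every per-class binary model.
val ovrModel = new OneVsRest()
  .setClassifier(new LogisticRegression())
  .setWeightCol("weight")
  .fit(df)

// DecisionTreeClassifier does not support weights on this branch, so weightCol is
// ignored and a warning is logged; training still succeeds.
val ovrModel2 = new OneVsRest()
  .setClassifier(new DecisionTreeClassifier())
  .setWeightCol("weight")
  .fit(df)
```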

File tree

4 files changed: +82, -9 lines changed


mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 36 additions & 3 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
+private[ml] trait OneVsRestParams extends PredictorParams
+  with ClassifierTypeTrait with HasWeightCol {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -290,6 +292,18 @@ final class OneVsRest @Since("1.4.0") (
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -308,7 +322,20 @@ final class OneVsRest @Since("1.4.0") (
     }
     val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
 
-    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+    val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
+      getClassifier match {
+        case _: HasWeightCol => true
+        case c =>
+          logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+          false
+      }
+    }
+
+    val multiclassLabeled = if (weightColIsUsed) {
+      dataset.select($(labelCol), $(featuresCol), $(weightCol))
+    } else {
+      dataset.select($(labelCol), $(featuresCol))
+    }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -328,7 +355,13 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.labelCol -> labelColName)
      paramMap.put(classifier.featuresCol -> getFeaturesCol)
      paramMap.put(classifier.predictionCol -> getPredictionCol)
-      classifier.fit(trainingDataset, paramMap)
+      if (weightColIsUsed) {
+        val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+        paramMap.put(classifier_.weightCol -> getWeightCol)
+        classifier_.fit(trainingDataset, paramMap)
+      } else {
+        classifier.fit(trainingDataset, paramMap)
+      }
    }.toArray[ClassificationModel[_, _]]

    if (handlePersistence) {
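As the hunk above shows, when the base classifier mixes in HasWeightCol, OneVsRest forwards the weight column through the per-class ParamMap at fit time. Below is a minimal sketch of that call shape, assuming a LogisticRegression base classifier; the relabeled column name `mc2b$0` and the `binaryTrainingDataset` value are illustrative placeholders, not taken from this diff.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression()

// OneVsRest overrides the relevant columns on the base classifier for each class;
// when weights are supported, the weight column goes into the same ParamMap.
val paramMap = ParamMap(
  lr.labelCol -> "mc2b$0",        // illustrative one-vs-rest label column for class 0
  lr.featuresCol -> "features",
  lr.predictionCol -> "prediction",
  lr.weightCol -> "weight")

// val binaryModel = lr.fit(binaryTrainingDataset, paramMap)  // binaryTrainingDataset assumed
```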

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.Metadata
 
 class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -143,6 +144,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
   }
 
+  test("SPARK-21306: OneVsRest should support setWeightCol") {
+    val dataset2 = dataset.withColumn("weight", lit(1.0))
+    // classifier inherits hasWeightCol
+    val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
+    assert(ova.fit(dataset2) !== null)
+    // classifier doesn't inherit hasWeightCol
+    val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
+    assert(ova2.fit(dataset2) !== null)
+  }
+
   test("OneVsRest.copy and OneVsRestModel.copy") {
     val lr = new LogisticRegression()
       .setMaxIter(1)

python/pyspark/ml/classification.py

Lines changed: 21 additions & 6 deletions
@@ -1252,7 +1252,7 @@ def weights(self):
         return self._call_java("weights")
 
 
-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
     """
     Parameters for OneVsRest and OneVsRestModel.
     """
@@ -1315,20 +1315,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None):
+                 classifier=None, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None)
+                 classifier=None, weightCol=None)
         """
         super(OneVsRest, self).__init__()
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+                  classifier=None, weightCol=None):
         """
-        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+                  classifier=None, weightCol=None):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
@@ -1344,7 +1346,18 @@ def _fit(self, dataset):
 
         numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
 
-        multiclassLabeled = dataset.select(labelCol, featuresCol)
+        weightCol = None
+        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+            if isinstance(classifier, HasWeightCol):
+                weightCol = self.getWeightCol()
+            else:
+                warnings.warn("weightCol is ignored, "
+                              "as it is not supported by {0} now.".format(classifier))
+
+        if weightCol:
+            multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+        else:
+            multiclassLabeled = dataset.select(labelCol, featuresCol)
 
         # persist if underlying dataset is not persistent.
         handlePersistence = \
@@ -1360,6 +1373,8 @@ def trainSingleClass(index):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                              (classifier.featuresCol, featuresCol),
                              (classifier.predictionCol, predictionCol)])
+            if weightCol:
+                paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)
 
         # TODO: Parallel training for all classes.

python/pyspark/ml/tests.py

Lines changed: 14 additions & 0 deletions
@@ -1128,6 +1128,20 @@ def test_output_columns(self):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits hasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit hasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))
+
 
 class HashingTFTest(SparkSessionTestCase):
 

0 commit comments
