Skip to content

Commit d4d762f

Browse files
Ram Sriharsha authored and jkbradley committed
[SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.
The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest.

Author: Ram Sriharsha <[email protected]>

Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits:

6591dc6 [Ram Sriharsha] add documentation for params
b7024b1 [Ram Sriharsha] cleanup
f0e2bfb [Ram Sriharsha] merge with master
108d3d7 [Ram Sriharsha] merge with master
4f74126 [Ram Sriharsha] Allow label/ features columns to be configurable
1 parent d249636 commit d4d762f

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
4747

4848
/**
4949
* param for the base binary classifier that we reduce multiclass classification into.
50+
* The base classifier input and output columns are ignored in favor of
51+
* the ones specified in [[OneVsRest]].
5052
* @group param
5153
*/
5254
val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
@@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
160162
set(classifier, value.asInstanceOf[ClassifierType])
161163
}
162164

165+
/** @group setParam */
166+
def setLabelCol(value: String): this.type = set(labelCol, value)
167+
168+
/** @group setParam */
169+
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
170+
171+
/** @group setParam */
172+
def setPredictionCol(value: String): this.type = set(predictionCol, value)
173+
163174
override def transformSchema(schema: StructType): StructType = {
164175
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
165176
}
@@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
195206
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
196207
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
197208
val classifier = getClassifier
198-
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
209+
val paramMap = new ParamMap()
210+
paramMap.put(classifier.labelCol -> labelColName)
211+
paramMap.put(classifier.featuresCol -> getFeaturesCol)
212+
paramMap.put(classifier.predictionCol -> getPredictionCol)
213+
classifier.fit(trainingDataset, paramMap)
199214
}.toArray[ClassificationModel[_, _]]
200215

201216
if (handlePersistence) {

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.ml.attribute.NominalAttribute
22+
import org.apache.spark.ml.feature.StringIndexer
2223
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
2324
import org.apache.spark.ml.util.MetadataUtils
2425
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
104105
ova.fit(datasetWithLabelMetadata)
105106
}
106107

108+
test("SPARK-8092: ensure label features and prediction cols are configurable") {
109+
val labelIndexer = new StringIndexer()
110+
.setInputCol("label")
111+
.setOutputCol("indexed")
112+
113+
val indexedDataset = labelIndexer
114+
.fit(dataset)
115+
.transform(dataset)
116+
.drop("label")
117+
.withColumnRenamed("features", "f")
118+
119+
val ova = new OneVsRest()
120+
ova.setClassifier(new LogisticRegression())
121+
.setLabelCol(labelIndexer.getOutputCol)
122+
.setFeaturesCol("f")
123+
.setPredictionCol("p")
124+
125+
val ovaModel = ova.fit(indexedDataset)
126+
val transformedDataset = ovaModel.transform(indexedDataset)
127+
val outputFields = transformedDataset.schema.fieldNames.toSet
128+
assert(outputFields.contains("p"))
129+
}
130+
107131
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
108132
val logReg = new LogisticRegression()
109133
.setMaxIter(1)

0 commit comments

Comments
 (0)