Skip to content

Commit 53fc7db

Browse files
mengxr authored and jeanlyn committed
[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output
The temporary column should be dropped after we get the prediction column. harsha2010

Author: Xiangrui Meng <[email protected]>

Closes apache#6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output
1 parent cc56082 commit 53fc7db

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
131131
// output label and label metadata as prediction
132132
val labelUdf = callUDF(label, DoubleType, col(accColName))
133133
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
134+
.drop(accColName)
134135
}
135136
}
136137

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
9393
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
9494
ova.fit(datasetWithLabelMetadata)
9595
}
96+
97+
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
98+
val logReg = new LogisticRegression()
99+
.setMaxIter(1)
100+
val ovr = new OneVsRest()
101+
.setClassifier(logReg)
102+
val output = ovr.fit(dataset).transform(dataset)
103+
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
104+
}
96105
}
97106

98107
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {

0 commit comments

Comments
 (0)