1717
1818package org .apache .spark .ml .classification
1919
20- import org .apache .spark .SparkFunSuite
20+ import org .scalatest .FunSuite
21+
2122import org .apache .spark .ml .attribute .NominalAttribute
2223import org .apache .spark .ml .util .MetadataUtils
2324import org .apache .spark .mllib .classification .LogisticRegressionWithLBFGS
@@ -29,7 +30,7 @@ import org.apache.spark.mllib.util.TestingUtils._
2930import org .apache .spark .rdd .RDD
3031import org .apache .spark .sql .DataFrame
3132
32- class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
33+ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
3334
3435 @ transient var dataset : DataFrame = _
3536 @ transient var rdd : RDD [LabeledPoint ] = _
@@ -93,6 +94,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
9394 val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
9495 ova.fit(datasetWithLabelMetadata)
9596 }
97+
98+ test(" SPARK-8049: OneVsRest shouldn't output temp columns" ) {
99+ val logReg = new LogisticRegression ()
100+ .setMaxIter(1 )
101+ val ovr = new OneVsRest ()
102+ .setClassifier(logReg)
103+ val output = ovr.fit(dataset).transform(dataset)
104+ assert(output.schema.fieldNames.toSet === Set (" label" , " features" , " prediction" ))
105+ }
96106}
97107
98108private class MockLogisticRegression (uid : String ) extends LogisticRegression (uid) {
0 commit comments