Skip to content

Commit 5b7067c

Browse files
Nick Pritchardmengxr
authored andcommitted
[SPARK-10573] [ML] IndexToString output schema should be StringType
Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard <[email protected]> Closes #8751 from pnpritchard/SPARK-10573. (cherry picked from commit 8a634e9) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 5f58704 commit 5b7067c

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
2727
import org.apache.spark.ml.util.Identifiable
2828
import org.apache.spark.sql.DataFrame
2929
import org.apache.spark.sql.functions._
30-
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
30+
import org.apache.spark.sql.types._
3131
import org.apache.spark.util.collection.OpenHashMap
3232

3333
/**
@@ -220,8 +220,7 @@ class IndexToString private[ml] (
220220
val outputColName = $(outputCol)
221221
require(inputFields.forall(_.name != outputColName),
222222
s"Output column $outputColName already exists.")
223-
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
224-
val outputFields = inputFields :+ attr.toStructField()
223+
val outputFields = inputFields :+ StructField($(outputCol), StringType)
225224
StructType(outputFields)
226225
}
227226

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
21+
import org.apache.spark.{SparkException, SparkFunSuite}
2122
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
2223
import org.apache.spark.ml.param.ParamsSuite
2324
import org.apache.spark.ml.util.MLTestingUtils
@@ -134,4 +135,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
134135
assert(a === b)
135136
}
136137
}
138+
139+
test("IndexToString.transformSchema (SPARK-10573)") {
140+
val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
141+
val inSchema = StructType(Seq(StructField("input", DoubleType)))
142+
val outSchema = idxToStr.transformSchema(inSchema)
143+
assert(outSchema("output").dataType === StringType)
144+
}
137145
}

0 commit comments

Comments
 (0)