Skip to content

Commit d69ef5e

Browse files
committed
Add a test for StringIndexer's handling of labels unseen during fitting
1 parent b5734be commit d69ef5e

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.ml.feature
1919

20+
import org.apache.spark.SparkException
2021
import org.apache.spark.SparkFunSuite
2122
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
2223
import org.apache.spark.ml.param.ParamsSuite
@@ -49,6 +50,38 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
4950
assert(output === expected)
5051
}
5152

53+
// Verifies StringIndexer behavior when transform() encounters labels that
// were absent from the fitting data ("unseen" labels):
//   1. by default, transforming a dataset containing an unseen label throws
//      a SparkException, and
//   2. with skipInvalid enabled, rows carrying unseen labels are dropped
//      from the output instead.
// NOTE: fixes the typo in the original test name ("Unessen" -> "Unseen").
test("StringIndexerUnseen") {
  // Fitting data contains only labels "a" and "b"; the second dataset adds
  // an unseen label "c".
  val data = sc.parallelize(Seq((0, "a"), (1, "b")), 2)
  val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
  val df = sqlContext.createDataFrame(data).toDF("id", "label")
  val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
  val indexer = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("labelIndex")
    .fit(df)
  // Verify we throw by default with unseen values
  intercept[SparkException] {
    indexer.transform(df2).collect()
  }
  val indexerSkipInvalid = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("labelIndex")
    .setSkipInvalid(true)
    .fit(df)
  // Verify that we skip the c record
  val transformed = indexerSkipInvalid.transform(df2)
  // The output column carries nominal metadata listing the fitted labels.
  val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
    .asInstanceOf[NominalAttribute]
  assert(attr.values.get === Array("b", "a"))
  val output = transformed.select("id", "labelIndex").map { r =>
    (r.getInt(0), r.getDouble(1))
  }.collect().toSet
  // a -> 1, b -> 0; row with unseen label "c" (id 2) is dropped.
  val expected = Set((0, 1.0), (1, 0.0))
  assert(output === expected)
}
5285
test("StringIndexer with a numeric input column") {
5386
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
5487
val df = sqlContext.createDataFrame(data).toDF("id", "label")

0 commit comments

Comments
 (0)