|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml.feature |
19 | 19 |
|
| 20 | +import org.apache.spark.SparkException |
20 | 21 | import org.apache.spark.SparkFunSuite |
21 | 22 | import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} |
22 | 23 | import org.apache.spark.ml.param.ParamsSuite |
@@ -49,6 +50,38 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { |
49 | 50 | assert(output === expected) |
50 | 51 | } |
51 | 52 |
|
| 53 | + test("StringIndexerUnseen") {
| 54 | + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
| 55 | + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
| 56 | + val df = sqlContext.createDataFrame(data).toDF("id", "label")
| 57 | + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
| 58 | + val indexer = new StringIndexer()
| 59 | + .setInputCol("label")
| 60 | + .setOutputCol("labelIndex")
| 61 | + .fit(df)
| 62 | + // Fitted on "a" and "b" only: by default, transforming the unseen label "c" must throw
| 63 | + intercept[SparkException] {
| 64 | + indexer.transform(df2).collect()
| 65 | + }
| 66 | + val indexerSkipInvalid = new StringIndexer()
| 67 | + .setInputCol("label")
| 68 | + .setOutputCol("labelIndex")
| 69 | + .setSkipInvalid(true)
| 70 | + .fit(df)
| 71 | + // With skipInvalid enabled, the row carrying the unseen label "c" must be absent from the output
| 72 | + val transformed = indexerSkipInvalid.transform(df2)
| 73 | + val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
| 74 | + .asInstanceOf[NominalAttribute]
| 75 | + assert(attr.values.get === Array("b", "a"))
| 76 | + val output = transformed.select("id", "labelIndex").map { r =>
| 77 | + (r.getInt(0), r.getDouble(1))
| 78 | + }.collect().toSet
| 79 | + // "b" occurs twice in the training data and "a" once, so no frequency tie: b -> 0, a -> 1
| 80 | + val expected = Set((0, 1.0), (1, 0.0))
| 81 | + assert(output === expected)
| 82 | + }
| 83 | + |
| 84 | + |
52 | 85 | test("StringIndexer with a numeric input column") { |
53 | 86 | val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) |
54 | 87 | val df = sqlContext.createDataFrame(data).toDF("id", "label") |
|
0 commit comments