Skip to content

Commit e5db190

Browse files
committed
address failed RFormula tests
1 parent 77bea32 commit e5db190

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
235235
/**
236236
* Model fitted by [[StringIndexer]].
237237
*
238-
* @param labelsArray Array of Ordered list of labels, corresponding to indices to be assigned
238+
* @param labelsArray Array of ordered lists of labels, corresponding to indices to be assigned
239239
* for each input column.
240240
*
241241
* @note During transformation, if the input column does not exist,
@@ -365,8 +365,14 @@ class StringIndexerModel (
365365
.as(outputColName, metadata)
366366
}
367367
}
368-
filteredDataset.withColumns(outputColNames.filter(_ != null),
369-
outputColumns.filter(_ != null))
368+
val filteredOutputColNames = outputColNames.filter(_ != null)
369+
val filteredOutputColumns = outputColumns.filter(_ != null)
370+
371+
if (filteredOutputColNames.length > 0) {
372+
filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns)
373+
} else {
374+
filteredDataset.toDF()
375+
}
370376
}
371377

372378
@Since("1.4.0")

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
114114

115115
test("encodes string terms") {
116116
val formula = new RFormula().setFormula("id ~ a + b")
117-
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
117+
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5),
118+
(5, "bar", 6), (6, "foo", 6))
118119
.toDF("id", "a", "b")
119120
val model = formula.fit(original)
120121
val result = model.transform(original)
@@ -123,7 +124,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
123124
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
124125
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
125126
(3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
126-
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
127+
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0),
128+
(5, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 5.0),
129+
(6, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 6.0)
127130
).toDF("id", "a", "b", "features", "label")
128131
assert(result.schema.toString == resultSchema.toString)
129132
assert(result.collect() === expected.collect())
@@ -299,15 +302,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
299302
test("index string label") {
300303
val formula = new RFormula().setFormula("id ~ a + b")
301304
val original =
302-
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
305+
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5),
306+
("female", "bar", 6), ("female", "foo", 6))
303307
.toDF("id", "a", "b")
304308
val model = formula.fit(original)
305309
val result = model.transform(original)
306310
val expected = Seq(
307311
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
308312
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
309313
("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
310-
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
314+
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0),
315+
("female", "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0),
316+
("female", "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 0.0)
311317
).toDF("id", "a", "b", "features", "label")
312318
// assert(result.schema.toString == resultSchema.toString)
313319
assert(result.collect() === expected.collect())
@@ -316,7 +322,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
316322
test("force to index label even it is numeric type") {
317323
val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true)
318324
val original = spark.createDataFrame(
319-
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
325+
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5),
326+
(1.0, "bar", 6), (0.0, "foo", 6))
320327
).toDF("id", "a", "b")
321328
val model = formula.fit(original)
322329
val result = model.transform(original)
@@ -325,14 +332,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
325332
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
326333
(1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
327334
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
328-
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
335+
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0),
336+
(1.0, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0),
337+
(0.0, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 1.0)
338+
)
329339
).toDF("id", "a", "b", "features", "label")
330340
assert(result.collect() === expected.collect())
331341
}
332342

333343
test("attribute generation") {
334344
val formula = new RFormula().setFormula("id ~ a + b")
335-
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
345+
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5),
346+
(1, "bar", 6), (0, "foo", 6))
336347
.toDF("id", "a", "b")
337348
val model = formula.fit(original)
338349
val result = model.transform(original)

0 commit comments

Comments
 (0)