|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml.feature |
19 | 19 |
|
20 | | -import org.apache.spark.{SparkException, SparkFunSuite} |
| 20 | +import org.apache.spark.SparkException |
21 | 21 | import org.apache.spark.ml.attribute._ |
22 | | -import org.apache.spark.ml.linalg.Vectors |
| 22 | +import org.apache.spark.ml.linalg.{Vector, Vectors} |
23 | 23 | import org.apache.spark.ml.param.ParamsSuite |
24 | | -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} |
25 | | -import org.apache.spark.mllib.util.MLlibTestSparkContext |
| 24 | +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} |
| 25 | +import org.apache.spark.sql.{DataFrame, Encoder, Row} |
26 | 26 | import org.apache.spark.sql.types.DoubleType |
27 | 27 |
|
28 | | -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { |
| 28 | +class RFormulaSuite extends MLTest with DefaultReadWriteTest { |
29 | 29 |
|
30 | 30 | import testImplicits._ |
31 | 31 |
|
@@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul |
548 | 548 | assert(result3.collect() === expected3.collect()) |
549 | 549 | assert(result4.collect() === expected4.collect()) |
550 | 550 | } |
| 551 | + |
| 552 | + test("Use Vectors as inputs to formula.") { |
| 553 | + val original = Seq( |
| 554 | + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), |
| 555 | + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), |
| 556 | + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), |
| 557 | + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) |
| 558 | + ).toDF("id", "a", "b") |
| 559 | + val formula = new RFormula().setFormula("id ~ a + b") |
| 560 | + val (first +: rest) = Seq("id", "a", "b", "features", "label") |
| 561 | + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { |
| 562 | + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => |
| 563 | + assert(label === id) |
| 564 | + assert(features.toArray === a +: b.toArray) |
| 565 | + } |
| 566 | + |
| 567 | + val group = new AttributeGroup("b", 3) |
| 568 | + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) |
| 569 | + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) |
| 570 | + val model = formula.fit(dfWithMetadata) |
| 571 | + // model should work even when applied to dataframe without metadata. |
| 572 | + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { |
| 573 | + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => |
| 574 | + assert(label === id) |
| 575 | + assert(features.toArray === a +: b.toArray) |
| 576 | + } |
| 577 | + } |
551 | 578 | } |
0 commit comments