Skip to content

Commit 09f0d60

Browse files
committed
Have RFormula include VectorSizeHint in pipeline.
1 parent 2250cb7 commit 09f0d60

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
2525
import org.apache.spark.annotation.{Experimental, Since}
2626
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
2727
import org.apache.spark.ml.attribute.AttributeGroup
28-
import org.apache.spark.ml.linalg.VectorUDT
28+
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
2929
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
3030
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
3131
import org.apache.spark.ml.util._
@@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
210210

211211
// First we index each string column referenced by the input terms.
212212
val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
213-
dataset.schema(term) match {
214-
case column if column.dataType == StringType =>
213+
dataset.schema(term).dataType match {
214+
case _: StringType =>
215215
val indexCol = tmpColumn("stridx")
216216
encoderStages += new StringIndexer()
217217
.setInputCol(term)
@@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
220220
.setHandleInvalid($(handleInvalid))
221221
prefixesToRewrite(indexCol + "_") = term + "_"
222222
(term, indexCol)
223+
case _: VectorUDT =>
224+
val group = AttributeGroup.fromStructField(dataset.schema(term))
225+
val size = if (group.size < 0) {
226+
dataset.select(term).first().getAs[Vector](0).size
227+
} else {
228+
group.size
229+
}
230+
encoderStages += new VectorSizeHint(uid)
231+
.setHandleInvalid("optimistic")
232+
.setInputCol(term)
233+
.setSize(size)
234+
(term, term)
223235
case _ =>
224236
(term, term)
225237
}

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.{SparkException, SparkFunSuite}
20+
import org.apache.spark.SparkException
2121
import org.apache.spark.ml.attribute._
22-
import org.apache.spark.ml.linalg.Vectors
22+
import org.apache.spark.ml.linalg.{Vector, Vectors}
2323
import org.apache.spark.ml.param.ParamsSuite
24-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
25-
import org.apache.spark.mllib.util.MLlibTestSparkContext
24+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
25+
import org.apache.spark.sql.{DataFrame, Encoder, Row}
2626
import org.apache.spark.sql.types.DoubleType
2727

28-
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
28+
class RFormulaSuite extends MLTest with DefaultReadWriteTest {
2929

3030
import testImplicits._
3131

@@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
548548
assert(result3.collect() === expected3.collect())
549549
assert(result4.collect() === expected4.collect())
550550
}
551+
552+
test("Use Vectors as inputs to formula.") {
553+
val original = Seq(
554+
(1, 4, Vectors.dense(0.0, 0.0, 4.0)),
555+
(2, 4, Vectors.dense(1.0, 0.0, 4.0)),
556+
(3, 5, Vectors.dense(1.0, 0.0, 5.0)),
557+
(4, 5, Vectors.dense(0.0, 1.0, 5.0))
558+
).toDF("id", "a", "b")
559+
val formula = new RFormula().setFormula("id ~ a + b")
560+
val (first +: rest) = Seq("id", "a", "b", "features", "label")
561+
testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) {
562+
case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
563+
assert(label === id)
564+
assert(features.toArray === a +: b.toArray)
565+
}
566+
567+
val group = new AttributeGroup("b", 3)
568+
val vectorColWithMetadata = original("b").as("b", group.toMetadata())
569+
val dfWithMetadata = original.withColumn("b", vectorColWithMetadata)
570+
val model = formula.fit(dfWithMetadata)
571+
// model should work even when applied to dataframe without metadata.
572+
testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) {
573+
case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
574+
assert(label === id)
575+
assert(features.toArray === a +: b.toArray)
576+
}
577+
}
551578
}

0 commit comments

Comments
 (0)