Skip to content

Commit 8ddfa52

Browse files
ericl authored and mengxr committed
[SPARK-9230] [ML] Support StringType features in RFormula
This adds StringType feature support via OneHotEncoder. As part of this task it was necessary to change RFormula to an Estimator, so that factor levels could be determined from the training dataset. I am not sure if I am using uids correctly here; it would be good to get reviewer help on that. cc mengxr Umbrella design doc: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit# Author: Eric Liang <[email protected]> Closes apache#7574 from ericl/string-features and squashes the following commits: f99131a [Eric Liang] comments 0bf3c26 [Eric Liang] update docs c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator
1 parent dafe8d8 commit 8ddfa52

File tree

4 files changed

+142
-62
lines changed

4 files changed

+142
-62
lines changed

R/pkg/inst/tests/test_mllib.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ test_that("glm and predict", {
3535

3636
test_that("predictions match with native glm", {
3737
training <- createDataFrame(sqlContext, iris)
38-
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
38+
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
3939
vals <- collect(select(predict(model, training), "prediction"))
40-
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
41-
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
40+
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
41+
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
4242
})

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 103 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,42 @@
1717

1818
package org.apache.spark.ml.feature
1919

20+
import scala.collection.mutable.ArrayBuffer
2021
import scala.util.parsing.combinator.RegexParsers
2122

2223
import org.apache.spark.annotation.Experimental
23-
import org.apache.spark.ml.Transformer
24+
import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
2425
import org.apache.spark.ml.param.{Param, ParamMap}
2526
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
2627
import org.apache.spark.ml.util.Identifiable
28+
import org.apache.spark.mllib.linalg.VectorUDT
2729
import org.apache.spark.sql.DataFrame
2830
import org.apache.spark.sql.functions._
2931
import org.apache.spark.sql.types._
3032

33+
/**
34+
* Base trait for [[RFormula]] and [[RFormulaModel]].
35+
*/
36+
private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
37+
/** @group setParam */
38+
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
39+
40+
/** @group setParam */
41+
def setLabelCol(value: String): this.type = set(labelCol, value)
42+
43+
protected def hasLabelCol(schema: StructType): Boolean = {
44+
schema.map(_.name).contains($(labelCol))
45+
}
46+
}
47+
3148
/**
3249
* :: Experimental ::
3350
* Implements the transforms required for fitting a dataset against an R model formula. Currently
3451
* we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
3552
* docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
3653
*/
3754
@Experimental
38-
class RFormula(override val uid: String)
39-
extends Transformer with HasFeaturesCol with HasLabelCol {
55+
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
4056

4157
def this() = this(Identifiable.randomUID("rFormula"))
4258

@@ -62,19 +78,74 @@ class RFormula(override val uid: String)
6278
/** @group getParam */
6379
def getFormula: String = $(formula)
6480

65-
/** @group getParam */
66-
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
81+
override def fit(dataset: DataFrame): RFormulaModel = {
82+
require(parsedFormula.isDefined, "Must call setFormula() first.")
83+
// StringType terms and terms representing interactions need to be encoded before assembly.
84+
// TODO(ekl) add support for feature interactions
85+
var encoderStages = ArrayBuffer[PipelineStage]()
86+
var tempColumns = ArrayBuffer[String]()
87+
val encodedTerms = parsedFormula.get.terms.map { term =>
88+
dataset.schema(term) match {
89+
case column if column.dataType == StringType =>
90+
val indexCol = term + "_idx_" + uid
91+
val encodedCol = term + "_onehot_" + uid
92+
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
93+
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
94+
tempColumns += indexCol
95+
tempColumns += encodedCol
96+
encodedCol
97+
case _ =>
98+
term
99+
}
100+
}
101+
encoderStages += new VectorAssembler(uid)
102+
.setInputCols(encodedTerms.toArray)
103+
.setOutputCol($(featuresCol))
104+
encoderStages += new ColumnPruner(tempColumns.toSet)
105+
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
106+
copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
107+
}
67108

68-
/** @group getParam */
69-
def setLabelCol(value: String): this.type = set(labelCol, value)
109+
// optimistic schema; does not contain any ML attributes
110+
override def transformSchema(schema: StructType): StructType = {
111+
if (hasLabelCol(schema)) {
112+
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
113+
} else {
114+
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
115+
StructField($(labelCol), DoubleType, true))
116+
}
117+
}
118+
119+
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
120+
121+
override def toString: String = s"RFormula(${get(formula)})"
122+
}
123+
124+
/**
125+
* :: Experimental ::
126+
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
127+
* @param parsedFormula a pre-parsed R formula.
128+
* @param pipelineModel the fitted feature model, including factor to index mappings.
129+
*/
130+
@Experimental
131+
class RFormulaModel private[feature](
132+
override val uid: String,
133+
parsedFormula: ParsedRFormula,
134+
pipelineModel: PipelineModel)
135+
extends Model[RFormulaModel] with RFormulaBase {
136+
137+
override def transform(dataset: DataFrame): DataFrame = {
138+
checkCanTransform(dataset.schema)
139+
transformLabel(pipelineModel.transform(dataset))
140+
}
70141

71142
override def transformSchema(schema: StructType): StructType = {
72143
checkCanTransform(schema)
73-
val withFeatures = transformFeatures.transformSchema(schema)
144+
val withFeatures = pipelineModel.transformSchema(schema)
74145
if (hasLabelCol(schema)) {
75146
withFeatures
76-
} else if (schema.exists(_.name == parsedFormula.get.label)) {
77-
val nullable = schema(parsedFormula.get.label).dataType match {
147+
} else if (schema.exists(_.name == parsedFormula.label)) {
148+
val nullable = schema(parsedFormula.label).dataType match {
78149
case _: NumericType | BooleanType => false
79150
case _ => true
80151
}
@@ -86,24 +157,19 @@ class RFormula(override val uid: String)
86157
}
87158
}
88159

89-
override def transform(dataset: DataFrame): DataFrame = {
90-
checkCanTransform(dataset.schema)
91-
transformLabel(transformFeatures.transform(dataset))
92-
}
93-
94-
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
160+
override def copy(extra: ParamMap): RFormulaModel = copyValues(
161+
new RFormulaModel(uid, parsedFormula, pipelineModel))
95162

96-
override def toString: String = s"RFormula(${get(formula)})"
163+
override def toString: String = s"RFormulaModel(${parsedFormula})"
97164

98165
private def transformLabel(dataset: DataFrame): DataFrame = {
99-
val labelName = parsedFormula.get.label
166+
val labelName = parsedFormula.label
100167
if (hasLabelCol(dataset.schema)) {
101168
dataset
102169
} else if (dataset.schema.exists(_.name == labelName)) {
103170
dataset.schema(labelName).dataType match {
104171
case _: NumericType | BooleanType =>
105172
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
106-
// TODO(ekl) add support for string-type labels
107173
case other =>
108174
throw new IllegalArgumentException("Unsupported type for label: " + other)
109175
}
@@ -114,25 +180,32 @@ class RFormula(override val uid: String)
114180
}
115181
}
116182

117-
private def transformFeatures: Transformer = {
118-
// TODO(ekl) add support for non-numeric features and feature interactions
119-
new VectorAssembler(uid)
120-
.setInputCols(parsedFormula.get.terms.toArray)
121-
.setOutputCol($(featuresCol))
122-
}
123-
124183
private def checkCanTransform(schema: StructType) {
125-
require(parsedFormula.isDefined, "Must call setFormula() first.")
126184
val columnNames = schema.map(_.name)
127185
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
128186
require(
129187
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
130188
"Label column already exists and is not of type DoubleType.")
131189
}
190+
}
132191

133-
private def hasLabelCol(schema: StructType): Boolean = {
134-
schema.map(_.name).contains($(labelCol))
192+
/**
193+
* Utility transformer for removing temporary columns from a DataFrame.
194+
* TODO(ekl) make this a public transformer
195+
*/
196+
private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
197+
override val uid = Identifiable.randomUID("columnPruner")
198+
199+
override def transform(dataset: DataFrame): DataFrame = {
200+
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
201+
dataset.select(columnsToKeep.map(dataset.col) : _*)
135202
}
203+
204+
override def transformSchema(schema: StructType): StructType = {
205+
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
206+
}
207+
208+
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
136209
}
137210

138211
/**
@@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers {
149222
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
150223

151224
def formula: Parser[ParsedRFormula] =
152-
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
225+
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
153226

154227
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
155228
case Success(result, _) => result

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite {
2828

2929
test("parse simple formulas") {
3030
checkParse("y ~ x", "y", Seq("x"))
31+
checkParse("y ~ x + x", "y", Seq("x"))
3132
checkParse("y ~ ._foo ", "y", Seq("._foo"))
3233
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
3334
}

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
3131
val formula = new RFormula().setFormula("id ~ v1 + v2")
3232
val original = sqlContext.createDataFrame(
3333
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
34-
val result = formula.transform(original)
35-
val resultSchema = formula.transformSchema(original.schema)
34+
val model = formula.fit(original)
35+
val result = model.transform(original)
36+
val resultSchema = model.transformSchema(original.schema)
3637
val expected = sqlContext.createDataFrame(
3738
Seq(
38-
(0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
39-
(2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
39+
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
40+
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
4041
).toDF("id", "v1", "v2", "features", "label")
4142
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
4243
assert(result.schema.toString == resultSchema.toString)
4344
assert(resultSchema == expected.schema)
44-
assert(result.collect().toSeq == expected.collect().toSeq)
45+
assert(result.collect() === expected.collect())
4546
}
4647

4748
test("features column already exists") {
4849
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
4950
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
5051
intercept[IllegalArgumentException] {
51-
formula.transformSchema(original.schema)
52+
formula.fit(original)
5253
}
5354
intercept[IllegalArgumentException] {
54-
formula.transform(original)
55+
formula.fit(original)
5556
}
5657
}
5758

5859
test("label column already exists") {
5960
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
6061
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
61-
val resultSchema = formula.transformSchema(original.schema)
62+
val model = formula.fit(original)
63+
val resultSchema = model.transformSchema(original.schema)
6264
assert(resultSchema.length == 3)
63-
assert(resultSchema.toString == formula.transform(original).schema.toString)
65+
assert(resultSchema.toString == model.transform(original).schema.toString)
6466
}
6567

6668
test("label column already exists but is not double type") {
6769
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
6870
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
71+
val model = formula.fit(original)
6972
intercept[IllegalArgumentException] {
70-
formula.transformSchema(original.schema)
73+
model.transformSchema(original.schema)
7174
}
7275
intercept[IllegalArgumentException] {
73-
formula.transform(original)
76+
model.transform(original)
7477
}
7578
}
7679

7780
test("allow missing label column for test datasets") {
7881
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
7982
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
80-
val resultSchema = formula.transformSchema(original.schema)
83+
val model = formula.fit(original)
84+
val resultSchema = model.transformSchema(original.schema)
8185
assert(resultSchema.length == 3)
8286
assert(!resultSchema.exists(_.name == "label"))
83-
assert(resultSchema.toString == formula.transform(original).schema.toString)
87+
assert(resultSchema.toString == model.transform(original).schema.toString)
8488
}
8589

86-
// TODO(ekl) enable after we implement string label support
87-
// test("transform string label") {
88-
// val formula = new RFormula().setFormula("name ~ id")
89-
// val original = sqlContext.createDataFrame(
90-
// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
91-
// val result = formula.transform(original)
92-
// val resultSchema = formula.transformSchema(original.schema)
93-
// val expected = sqlContext.createDataFrame(
94-
// Seq(
95-
// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
96-
// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
97-
// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
98-
// ).toDF("id", "name", "features", "label")
99-
// assert(result.schema.toString == resultSchema.toString)
100-
// assert(result.collect().toSeq == expected.collect().toSeq)
101-
// }
90+
test("encodes string terms") {
91+
val formula = new RFormula().setFormula("id ~ a + b")
92+
val original = sqlContext.createDataFrame(
93+
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
94+
).toDF("id", "a", "b")
95+
val model = formula.fit(original)
96+
val result = model.transform(original)
97+
val resultSchema = model.transformSchema(original.schema)
98+
val expected = sqlContext.createDataFrame(
99+
Seq(
100+
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
101+
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
102+
(3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
103+
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
104+
).toDF("id", "a", "b", "features", "label")
105+
assert(result.schema.toString == resultSchema.toString)
106+
assert(result.collect() === expected.collect())
107+
}
102108
}

0 commit comments

Comments
 (0)