
Commit 6960a79

ericl authored and mengxr committed
[SPARK-8774] [ML] Add R model formula with basic support as a transformer
This implements minimal R formula support as a feature transformer. Numeric and boolean labels are supported, but string labels and non-numeric features are not handled yet (see the TODOs in the code).

cc mengxr

Author: Eric Liang <[email protected]>

Closes #7381 from ericl/spark-8774-1 and squashes the following commits:

d1959d2 [Eric Liang] clarify comment
2db68aa [Eric Liang] second round of comments
dc3c943 [Eric Liang] address comments
5765ec6 [Eric Liang] fix style checks
1f361b0 [Eric Liang] doc
fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer
1 parent b064519 commit 6960a79
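
For context, a minimal usage sketch of the transformer this commit adds. The column names are hypothetical, and sqlContext is assumed to be an existing SQLContext (as in the test suites below); the default output columns "features" and "label" come from HasFeaturesCol/HasLabelCol.

import org.apache.spark.ml.feature.RFormula

// Numeric-only input, matching the current limitations: "clicks" becomes
// the label, while "age" and "income" are assembled into a feature vector.
val df = sqlContext.createDataFrame(Seq(
  (1.0, 23.0, 50000.0),
  (0.0, 31.0, 72000.0))).toDF("clicks", "age", "income")

val formula = new RFormula().setFormula("clicks ~ age + income")

// Appends a "features" vector column and a double "label" column.
formula.transform(df).show()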

File tree

4 files changed: +279 -1 lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * :: Experimental ::
 * Implements the transforms required for fitting a dataset against an R model formula. Currently
 * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
 * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
 */
@Experimental
class RFormula(override val uid: String)
  extends Transformer with HasFeaturesCol with HasLabelCol {

  def this() = this(Identifiable.randomUID("rFormula"))

  /**
   * R formula parameter. The formula is provided in string form.
   * @group param
   */
  val formula: Param[String] = new Param(this, "formula", "R model formula")

  private var parsedFormula: Option[ParsedRFormula] = None

  /**
   * Sets the formula to use for this transformer. Must be called before use.
   * @group setParam
   * @param value an R formula in string form (e.g. "y ~ x + z")
   */
  def setFormula(value: String): this.type = {
    parsedFormula = Some(RFormulaParser.parse(value))
    set(formula, value)
    this
  }

  /** @group getParam */
  def getFormula: String = $(formula)

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  override def transformSchema(schema: StructType): StructType = {
    checkCanTransform(schema)
    val withFeatures = transformFeatures.transformSchema(schema)
    if (hasLabelCol(schema)) {
      withFeatures
    } else {
      val nullable = schema(parsedFormula.get.label).dataType match {
        case _: NumericType | BooleanType => false
        case _ => true
      }
      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
    }
  }

  override def transform(dataset: DataFrame): DataFrame = {
    checkCanTransform(dataset.schema)
    transformLabel(transformFeatures.transform(dataset))
  }

  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)

  override def toString: String = s"RFormula(${get(formula)})"

  private def transformLabel(dataset: DataFrame): DataFrame = {
    if (hasLabelCol(dataset.schema)) {
      dataset
    } else {
      val labelName = parsedFormula.get.label
      dataset.schema(labelName).dataType match {
        case _: NumericType | BooleanType =>
          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
        // TODO(ekl) add support for string-type labels
        case other =>
          throw new IllegalArgumentException("Unsupported type for label: " + other)
      }
    }
  }

  private def transformFeatures: Transformer = {
    // TODO(ekl) add support for non-numeric features and feature interactions
    new VectorAssembler(uid)
      .setInputCols(parsedFormula.get.terms.toArray)
      .setOutputCol($(featuresCol))
  }

  private def checkCanTransform(schema: StructType) {
    require(parsedFormula.isDefined, "Must call setFormula() first.")
    val columnNames = schema.map(_.name)
    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
    require(
      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
      "Label column already exists and is not of type DoubleType.")
  }

  private def hasLabelCol(schema: StructType): Boolean = {
    schema.map(_.name).contains($(labelCol))
  }
}

/**
 * Represents a parsed R formula.
 */
private[ml] case class ParsedRFormula(label: String, terms: Seq[String])

/**
 * Limited implementation of R formula parsing. Currently supports: '~', '+'.
 */
private[ml] object RFormulaParser extends RegexParsers {
  def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r

  def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }

  def formula: Parser[ParsedRFormula] =
    (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }

  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
    case Success(result, _) => result
    case failure: NoSuccess => throw new IllegalArgumentException(
      "Could not parse formula: " + value)
  }
}
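
A quick sketch of the parser's behavior for reference. RFormulaParser is private[ml], so this only compiles from within the org.apache.spark.ml package (e.g. in a test); the same cases are exercised by RFormulaParserSuite below.

// '~' splits the label from the right-hand side; '+' separates terms.
val parsed = RFormulaParser.parse("y ~ x + z")
assert(parsed.label == "y")
assert(parsed.terms == Seq("x", "z"))

// Unsupported operators are rejected: parseAll fails on leftover input,
// so RFormulaParser.parse("y ~ x * z") throws IllegalArgumentException.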

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
     if (schema.fieldNames.contains(outputColName)) {
       throw new IllegalArgumentException(s"Output column $outputColName already exists.")
     }
-    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
   }

   override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
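
A small sketch of the effect of this one-line change (column names are hypothetical): the vector column appended by transformSchema is now declared nullable.

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(Seq(
  StructField("v1", DoubleType, nullable = false),
  StructField("v2", DoubleType, nullable = false)))
val assembler = new VectorAssembler()
  .setInputCols(Array("v1", "v2"))
  .setOutputCol("features")

// Previously this field was created with nullable = false.
assert(assembler.transformSchema(schema)("features").nullable)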

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite

class RFormulaParserSuite extends SparkFunSuite {
  private def checkParse(formula: String, label: String, terms: Seq[String]) {
    val parsed = RFormulaParser.parse(formula)
    assert(parsed.label == label)
    assert(parsed.terms == terms)
  }

  test("parse simple formulas") {
    checkParse("y ~ x", "y", Seq("x"))
    checkParse("y ~ ._foo ", "y", Seq("._foo"))
    checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
  }
}

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("params") {
    ParamsSuite.checkParams(new RFormula())
  }

  test("transform numeric data") {
    val formula = new RFormula().setFormula("id ~ v1 + v2")
    val original = sqlContext.createDataFrame(
      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
    val result = formula.transform(original)
    val resultSchema = formula.transformSchema(original.schema)
    val expected = sqlContext.createDataFrame(
      Seq(
        (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
        (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
      ).toDF("id", "v1", "v2", "features", "label")
    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
    assert(result.schema.toString == resultSchema.toString)
    assert(resultSchema == expected.schema)
    assert(result.collect().toSeq == expected.collect().toSeq)
  }

  test("features column already exists") {
    val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    intercept[IllegalArgumentException] {
      formula.transformSchema(original.schema)
    }
    intercept[IllegalArgumentException] {
      formula.transform(original)
    }
  }

  test("label column already exists") {
    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    val resultSchema = formula.transformSchema(original.schema)
    assert(resultSchema.length == 3)
    assert(resultSchema.toString == formula.transform(original).schema.toString)
  }

  test("label column already exists but is not double type") {
    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
    intercept[IllegalArgumentException] {
      formula.transformSchema(original.schema)
    }
    intercept[IllegalArgumentException] {
      formula.transform(original)
    }
  }

  // TODO(ekl) enable after we implement string label support
  // test("transform string label") {
  //   val formula = new RFormula().setFormula("name ~ id")
  //   val original = sqlContext.createDataFrame(
  //     Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
  //   val result = formula.transform(original)
  //   val resultSchema = formula.transformSchema(original.schema)
  //   val expected = sqlContext.createDataFrame(
  //     Seq(
  //       (1, "foo", Vectors.dense(Array(1.0)), 1.0),
  //       (2, "bar", Vectors.dense(Array(2.0)), 0.0),
  //       (3, "bar", Vectors.dense(Array(3.0)), 0.0))
  //     ).toDF("id", "name", "features", "label")
  //   assert(result.schema.toString == resultSchema.toString)
  //   assert(result.collect().toSeq == expected.collect().toSeq)
  // }
}

0 commit comments
