Skip to content

Commit fb0826b

Browse files
committed
[SPARK-8774] Add R model formula with basic support as a transformer
1 parent c185f3a commit fb0826b

File tree

3 files changed

+196
-1
lines changed

3 files changed

+196
-1
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import scala.util.parsing.combinator.RegexParsers
21+
22+
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.ml.Transformer
24+
import org.apache.spark.ml.param.{Param, ParamMap}
25+
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
26+
import org.apache.spark.ml.util.Identifiable
27+
import org.apache.spark.sql.DataFrame
28+
import org.apache.spark.sql.functions._
29+
import org.apache.spark.sql.types._
30+
31+
/**
32+
* :: Experimental ::
33+
* Implements the transforms required for fitting a dataset against a R model formula.
34+
*/
35+
@Experimental
36+
private[spark] class RModelFormula(override val uid: String)
37+
extends Transformer with HasFeaturesCol with HasLabelCol {
38+
39+
def this() = this(Identifiable.randomUID("rModelFormula"))
40+
41+
val formula: Param[String] = new Param(this, "formula", "R model formula")
42+
protected var parsedFormula: Option[RFormula] = None
43+
44+
def setFormula(value: String): this.type = {
45+
parsedFormula = Some(RFormulaParser.parse(value))
46+
set(formula, value)
47+
this
48+
}
49+
50+
override def transformSchema(schema: StructType): StructType = {
51+
require(parsedFormula.isDefined, "Must call setFormula() first.")
52+
val withFeatures = featureTransformer.transformSchema(schema)
53+
val nullable = schema(parsedFormula.get.response).dataType match {
54+
case _: NumericType | BooleanType => false
55+
case _ => true
56+
}
57+
StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
58+
}
59+
60+
override def transform(dataset: DataFrame): DataFrame = {
61+
require(parsedFormula.isDefined, "Must call setFormula() first.")
62+
transformLabel(featureTransformer.transform(dataset))
63+
}
64+
65+
override def copy(extra: ParamMap): RModelFormula = defaultCopy(extra)
66+
67+
override def toString = s"RModelFormula(${get(formula)})"
68+
69+
protected def transformLabel(dataset: DataFrame): DataFrame = {
70+
val responseName = parsedFormula.get.response
71+
dataset.schema(responseName).dataType match {
72+
case _: NumericType | BooleanType =>
73+
dataset.select(
74+
col("*"),
75+
dataset(responseName).cast(DoubleType).as($(labelCol)))
76+
case StringType =>
77+
new StringIndexer(uid)
78+
.setInputCol(responseName)
79+
.setOutputCol($(labelCol))
80+
.fit(dataset)
81+
.transform(dataset)
82+
case other =>
83+
throw new IllegalArgumentException("Unsupported type for response: " + other)
84+
}
85+
}
86+
87+
protected def featureTransformer: Transformer = {
88+
// TODO(ekl) add support for non-numeric features and feature interactions
89+
new VectorAssembler(uid)
90+
.setInputCols(parsedFormula.get.terms.toArray)
91+
.setOutputCol($(featuresCol))
92+
}
93+
}
94+
95+
/**
96+
* :: Experimental ::
97+
* Represents a parsed R formula.
98+
*/
99+
private[ml] case class RFormula(response: String, terms: Seq[String])
100+
101+
/**
102+
* :: Experimental ::
103+
* Limited implementation of R formula parsing. Currently supports: '~', '+'.
104+
*/
105+
private[ml] object RFormulaParser extends RegexParsers {
106+
def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
107+
108+
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
109+
110+
def formula: Parser[RFormula] = (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => RFormula(r, t) }
111+
112+
def parse(value: String): RFormula = parseAll(formula, value) match {
113+
case Success(result, _) => result
114+
case failure: NoSuccess => throw new IllegalArgumentException(
115+
"Could not parse formula: " + value)
116+
}
117+
}

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
116116
if (schema.fieldNames.contains(outputColName)) {
117117
throw new IllegalArgumentException(s"Output column $outputColName already exists.")
118118
}
119-
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
119+
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
120120
}
121121

122122
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.ml.param.ParamsSuite
22+
import org.apache.spark.mllib.linalg.Vectors
23+
import org.apache.spark.mllib.util.MLlibTestSparkContext
24+
import org.apache.spark.mllib.util.TestingUtils._
25+
26+
class RFormulaModelSuite extends SparkFunSuite with MLlibTestSparkContext {
27+
test("params") {
28+
ParamsSuite.checkParams(new RModelFormula())
29+
}
30+
31+
test("parse simple formulas") {
32+
def check(formula: String, response: String, terms: Seq[String]) {
33+
new RModelFormula().setFormula(formula)
34+
val parsed = RFormulaParser.parse(formula)
35+
assert(parsed.response == response)
36+
assert(parsed.terms == terms)
37+
}
38+
check("y ~ x", "y", Seq("x"))
39+
check("y ~ ._foo ", "y", Seq("._foo"))
40+
check("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
41+
}
42+
43+
test("transform numeric data") {
44+
val formula = new RModelFormula().setFormula("id ~ v1 + v2")
45+
val original = sqlContext.createDataFrame(
46+
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
47+
val result = formula.transform(original)
48+
val resultSchema = formula.transformSchema(original.schema)
49+
val expected = sqlContext.createDataFrame(
50+
Seq(
51+
(0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
52+
(2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
53+
).toDF("id", "v1", "v2", "features", "label")
54+
assert(result.schema.toString == resultSchema.toString)
55+
assert(resultSchema.toString == expected.schema.toString)
56+
assert(
57+
result.collect.map(_.toString).mkString(",") ==
58+
expected.collect.map(_.toString).mkString(","))
59+
}
60+
61+
test("transform string label") {
62+
val formula = new RModelFormula().setFormula("name ~ id")
63+
val original = sqlContext.createDataFrame(
64+
Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
65+
val result = formula.transform(original)
66+
val resultSchema = formula.transformSchema(original.schema)
67+
val expected = sqlContext.createDataFrame(
68+
Seq(
69+
(1, "foo", Vectors.dense(Array(1.0)), 1.0),
70+
(2, "bar", Vectors.dense(Array(2.0)), 0.0),
71+
(3, "bar", Vectors.dense(Array(3.0)), 0.0))
72+
).toDF("id", "name", "features", "label")
73+
assert(result.schema.toString == resultSchema.toString)
74+
assert(
75+
result.collect.map(_.toString).mkString(",") ==
76+
expected.collect.map(_.toString).mkString(","))
77+
}
78+
}

0 commit comments

Comments
 (0)