Skip to content

Commit 2db68aa

Browse files
committed
second round of comments
1 parent dc3c943 commit 2db68aa

File tree

4 files changed

+176
-112
lines changed

4 files changed

+176
-112
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RModelFormula.scala renamed to mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@ import org.apache.spark.sql.types._
3232
* :: Experimental ::
3333
* Implements the transforms required for fitting a dataset against an R model formula. Currently
3434
* we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
35-
* docs here: http://www.inside-r.org/r-doc/stats/formula
35+
* docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
3636
*/
3737
@Experimental
38-
class RModelFormula(override val uid: String)
38+
class RFormula(override val uid: String)
3939
extends Transformer with HasFeaturesCol with HasLabelCol {
4040

41-
def this() = this(Identifiable.randomUID("rModelFormula"))
41+
def this() = this(Identifiable.randomUID("rFormula"))
4242

4343
/**
4444
* R formula parameter. The formula is provided in string form.
4545
* @group setParam
4646
*/
4747
val formula: Param[String] = new Param(this, "formula", "R model formula")
4848

49-
private var parsedFormula: Option[RFormula] = None
49+
private var parsedFormula: Option[ParsedRFormula] = None
5050

5151
/**
5252
* Sets the formula to use for this transformer. Must be called before use.
@@ -63,60 +63,74 @@ class RModelFormula(override val uid: String)
6363
def getFormula: String = $(formula)
6464

6565
/** @group getParam */
66-
def setFeaturesCol(col: String): this.type = set(featuresCol, col)
66+
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
6767

6868
/** @group getParam */
69-
def setLabelCol(col: String): this.type = set(labelCol, col)
69+
def setLabelCol(value: String): this.type = set(labelCol, value)
7070

7171
override def transformSchema(schema: StructType): StructType = {
72-
require(parsedFormula.isDefined, "Must call setFormula() first.")
73-
val withFeatures = featureTransformer.transformSchema(schema)
74-
val nullable = schema(parsedFormula.get.response).dataType match {
75-
case _: NumericType | BooleanType => false
76-
case _ => true
72+
checkCanTransform(schema)
73+
val withFeatures = transformFeatures.transformSchema(schema)
74+
if (hasLabelCol(schema)) {
75+
withFeatures
76+
} else {
77+
val nullable = schema(parsedFormula.get.label).dataType match {
78+
case _: NumericType | BooleanType => false
79+
case _ => true
80+
}
81+
StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
7782
}
78-
StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
7983
}
8084

8185
override def transform(dataset: DataFrame): DataFrame = {
82-
require(parsedFormula.isDefined, "Must call setFormula() first.")
83-
transformLabel(featureTransformer.transform(dataset))
86+
checkCanTransform(dataset.schema)
87+
transformLabel(transformFeatures.transform(dataset))
8488
}
8589

86-
override def copy(extra: ParamMap): RModelFormula = defaultCopy(extra)
90+
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
8791

88-
override def toString: String = s"RModelFormula(${get(formula)})"
92+
override def toString: String = s"RFormula(${get(formula)})"
8993

9094
private def transformLabel(dataset: DataFrame): DataFrame = {
91-
val responseName = parsedFormula.get.response
92-
dataset.schema(responseName).dataType match {
93-
case _: NumericType | BooleanType =>
94-
dataset.select(
95-
col("*"),
96-
dataset(responseName).cast(DoubleType).as($(labelCol)))
97-
case StringType =>
98-
new StringIndexer()
99-
.setInputCol(responseName)
100-
.setOutputCol($(labelCol))
101-
.fit(dataset)
102-
.transform(dataset)
103-
case other =>
104-
throw new IllegalArgumentException("Unsupported type for response: " + other)
95+
if (hasLabelCol(dataset.schema)) {
96+
dataset
97+
} else {
98+
val labelName = parsedFormula.get.label
99+
dataset.schema(labelName).dataType match {
100+
case _: NumericType | BooleanType =>
101+
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
102+
// TODO(ekl) add support for string-type labels
103+
case other =>
104+
throw new IllegalArgumentException("Unsupported type for label: " + other)
105+
}
105106
}
106107
}
107108

108-
private def featureTransformer: Transformer = {
109+
private def transformFeatures: Transformer = {
109110
// TODO(ekl) add support for non-numeric features and feature interactions
110111
new VectorAssembler(uid)
111112
.setInputCols(parsedFormula.get.terms.toArray)
112113
.setOutputCol($(featuresCol))
113114
}
115+
116+
private def checkCanTransform(schema: StructType) {
117+
require(parsedFormula.isDefined, "Must call setFormula() first.")
118+
val columnNames = schema.map(_.name)
119+
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
120+
require(
121+
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
122+
"Label column already exists and is not of type DoubleType.")
123+
}
124+
125+
private def hasLabelCol(schema: StructType): Boolean = {
126+
schema.map(_.name).contains($(labelCol))
127+
}
114128
}
115129

116130
/**
117131
* Represents a parsed R formula.
118132
*/
119-
private[ml] case class RFormula(response: String, terms: Seq[String])
133+
private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
120134

121135
/**
122136
* Limited implementation of R formula parsing. Currently supports: '~', '+'.
@@ -126,9 +140,10 @@ private[ml] object RFormulaParser extends RegexParsers {
126140

127141
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
128142

129-
def formula: Parser[RFormula] = (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => RFormula(r, t) }
143+
def formula: Parser[ParsedRFormula] =
144+
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
130145

131-
def parse(value: String): RFormula = parseAll(formula, value) match {
146+
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
132147
case Success(result, _) => result
133148
case failure: NoSuccess => throw new IllegalArgumentException(
134149
"Could not parse formula: " + value)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.SparkFunSuite
21+
22+
class RFormulaParserSuite extends SparkFunSuite {
23+
private def checkParse(formula: String, label: String, terms: Seq[String]) {
24+
val parsed = RFormulaParser.parse(formula)
25+
assert(parsed.label == label)
26+
assert(parsed.terms == terms)
27+
}
28+
29+
test("parse simple formulas") {
30+
checkParse("y ~ x", "y", Seq("x"))
31+
checkParse("y ~ ._foo ", "y", Seq("._foo"))
32+
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
33+
}
34+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.ml.param.ParamsSuite
22+
import org.apache.spark.mllib.linalg.Vectors
23+
import org.apache.spark.mllib.util.MLlibTestSparkContext
24+
25+
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
26+
test("params") {
27+
ParamsSuite.checkParams(new RFormula())
28+
}
29+
30+
test("transform numeric data") {
31+
val formula = new RFormula().setFormula("id ~ v1 + v2")
32+
val original = sqlContext.createDataFrame(
33+
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
34+
val result = formula.transform(original)
35+
val resultSchema = formula.transformSchema(original.schema)
36+
val expected = sqlContext.createDataFrame(
37+
Seq(
38+
(0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
39+
(2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
40+
).toDF("id", "v1", "v2", "features", "label")
41+
// TODO(ekl) make schema comparisons check metadata, to avoid .toString
42+
assert(result.schema.toString == resultSchema.toString)
43+
assert(resultSchema == expected.schema)
44+
assert(result.collect().toSeq == expected.collect().toSeq)
45+
}
46+
47+
test("features column already exists") {
48+
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
49+
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
50+
intercept[IllegalArgumentException] {
51+
formula.transformSchema(original.schema)
52+
}
53+
intercept[IllegalArgumentException] {
54+
formula.transform(original)
55+
}
56+
}
57+
58+
test("label column already exists") {
59+
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
60+
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
61+
val resultSchema = formula.transformSchema(original.schema)
62+
assert(resultSchema.length == 3)
63+
assert(resultSchema.toString == formula.transform(original).schema.toString)
64+
}
65+
66+
test("label column already exists but is not double type") {
67+
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
68+
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
69+
intercept[IllegalArgumentException] {
70+
formula.transformSchema(original.schema)
71+
}
72+
intercept[IllegalArgumentException] {
73+
formula.transform(original)
74+
}
75+
}
76+
77+
// TODO(ekl) enable after we implement string label support
78+
// test("transform string label") {
79+
// val formula = new RFormula().setFormula("name ~ id")
80+
// val original = sqlContext.createDataFrame(
81+
// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
82+
// val result = formula.transform(original)
83+
// val resultSchema = formula.transformSchema(original.schema)
84+
// val expected = sqlContext.createDataFrame(
85+
// Seq(
86+
// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
87+
// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
88+
// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
89+
// ).toDF("id", "name", "features", "label")
90+
// assert(result.schema.toString == resultSchema.toString)
91+
// assert(result.collect().toSeq == expected.collect().toSeq)
92+
// }
93+
}

mllib/src/test/scala/org/apache/spark/ml/feature/RModelFormulaSuite.scala

Lines changed: 0 additions & 78 deletions
This file was deleted.

0 commit comments

Comments
 (0)