Skip to content

Commit 4c11a77

Browse files
committed
small nits
1 parent 5f7cb9b commit 4c11a77

File tree

8 files changed

+57
-42
lines changed

8 files changed

+57
-42
lines changed

R/pkg/R/mllib.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
2727
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
2828
#'
2929
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
30-
#' operators are supported, including '~', '+', '-', and '.'.
30+
#' operators are supported, including '~', '.', ':', '+', and '-'.
3131
#' @param data DataFrame for training
3232
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
3333
#' @param lambda Regularization parameter

R/pkg/inst/tests/test_mllib.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", {
4949
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
5050
})
5151

52+
test_that("feature interaction vs native glm", {
53+
training <- createDataFrame(sqlContext, iris)
54+
model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
55+
vals <- collect(select(predict(model, training), "prediction"))
56+
rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
57+
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
58+
})
59+
5260
test_that("summary coefficients match with native glm", {
5361
training <- createDataFrame(sqlContext, iris)
5462
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
@@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", {
5765
expect_true(all(abs(rCoefs - coefs) < 1e-6))
5866
expect_true(all(
5967
as.character(stats$features) ==
60-
c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
68+
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
6169
})

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
4747
/**
4848
* :: Experimental ::
4949
* Implements the transforms required for fitting a dataset against an R model formula. Currently
50-
* we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
51-
* docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
50+
* we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
51+
* the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
5252
*/
5353
@Experimental
5454
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
@@ -86,6 +86,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
8686
val takenNames = mutable.Set(dataset.columns: _*)
8787
def encodeInteraction(terms: Seq[String]): String = {
8888
val outputCol = {
89+
// TODO(ekl) this column naming should be unnecessary since we generate the right attr
90+
// names in RInteraction, but the name is lost somewhere before VectorAssembler.
8991
var tmp = terms.mkString(":")
9092
while (takenNames.contains(tmp)) {
9193
tmp += "_"
@@ -99,7 +101,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
99101
tempColumns += outputCol
100102
outputCol
101103
}
102-
val encodedCols = resolvedFormula.terms.map {
104+
val encodedTerms = resolvedFormula.terms.map {
103105
case terms @ Seq(value) =>
104106
dataset.schema(value) match {
105107
case column if column.dataType == StringType =>
@@ -111,7 +113,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
111113
encodeInteraction(terms)
112114
}
113115
encoderStages += new VectorAssembler(uid)
114-
.setInputCols(encodedCols.toArray)
116+
.setInputCols(encodedTerms.toArray)
115117
.setOutputCol($(featuresCol))
116118
encoderStages += new ColumnPruner(tempColumns.toSet)
117119
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)

mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
3232
* of the special '.' term. Duplicate terms will be removed during resolution.
3333
*/
3434
def resolve(schema: StructType): ResolvedRFormula = {
35-
lazy val dotTerms = expandDot(schema)
35+
val dotTerms = expandDot(schema)
3636
var includedTerms = Seq[Seq[String]]()
3737
terms.foreach {
3838
case term: ColumnRef =>
@@ -80,29 +80,30 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
8080
private def expandInteraction(
8181
schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = {
8282
if (terms.isEmpty) {
83-
Seq(Nil)
84-
} else {
85-
val rest = expandInteraction(schema, terms.tail)
86-
val validInteractions = (terms.head match {
87-
case Dot =>
88-
expandDot(schema).filter(_ != label.value).flatMap { t =>
89-
rest.map { r =>
90-
Seq(t) ++ r
91-
}
92-
}
93-
case ColumnRef(value) =>
94-
rest.map(Seq(value) ++ _)
95-
}).map(_.distinct)
96-
// Deduplicates feature interactions, for example, a:b is the same as b:a.
97-
var seen = mutable.Set[Set[String]]()
98-
validInteractions.flatMap {
99-
case t if seen.contains(t.toSet) =>
100-
None
101-
case t =>
102-
seen += t.toSet
103-
Some(t)
104-
}.sortBy(_.length)
83+
return Seq(Nil)
10584
}
85+
86+
val rest = expandInteraction(schema, terms.tail)
87+
val validInteractions = (terms.head match {
88+
case Dot =>
89+
expandDot(schema).filter(_ != label.value).flatMap { t =>
90+
rest.map { r =>
91+
Seq(t) ++ r
92+
}
93+
}
94+
case ColumnRef(value) =>
95+
rest.map(Seq(value) ++ _)
96+
}).map(_.distinct)
97+
98+
// Deduplicates feature interactions, for example, a:b is the same as b:a.
99+
var seen = mutable.Set[Set[String]]()
100+
validInteractions.flatMap {
101+
case t if seen.contains(t.toSet) =>
102+
None
103+
case t =>
104+
seen += t.toSet
105+
Some(t)
106+
}.sortBy(_.length)
106107
}
107108

108109
// the dot operator excludes complex column types
@@ -116,6 +117,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
116117

117118
/**
118119
* Represents a fully evaluated and simplified R formula.
120+
* @param label the column name of the R formula label (response variable).
121+
* @param terms the simplified terms of the R formula. Interaction terms are represented as Seqs
122+
* of column names; non-interaction terms as length 1 Seqs.
119123
*/
120124
private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]])
121125

mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import org.apache.spark.sql.types._
4141
* See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more
4242
* information about factor interactions in R formulae.
4343
*/
44+
// TODO(ekl) it might be nice to have standalone tests for RInteraction.
4445
@Experimental
4546
class RInteraction(override val uid: String) extends Estimator[PipelineModel]
4647
with HasInputCols with HasOutputCol {
@@ -127,8 +128,8 @@ class RInteraction(override val uid: String) extends Estimator[PipelineModel]
127128
}
128129

129130
/**
130-
* This helper class combines the output of multiple string-indexed columns to simulate
131-
* the joint indexing of tuples containing all the column values.
131+
* This helper class computes the joint index of multiple string-indexed columns such that the
132+
* combined index covers the Cartesian product of column values.
132133
*/
133134
private class IndexCombiner(
134135
inputCols: Array[String], attrNames: Array[String], outputCol: String)

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class RFormulaParserSuite extends SparkFunSuite {
2525
formula: String,
2626
label: String,
2727
terms: Seq[String],
28-
schema: StructType = null) {
28+
schema: StructType = new StructType) {
2929
val resolved = RFormulaParser.parse(formula).resolve(schema)
3030
assert(resolved.label == label)
3131
val simpleTerms = terms.map { t =>

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,26 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
125125
}
126126

127127
test("numeric interaction") {
128-
val formula = new RFormula().setFormula("a ~ b:c")
128+
val formula = new RFormula().setFormula("a ~ b:c:d")
129129
val original = sqlContext.createDataFrame(
130-
Seq((1, 2, 4), (2, 3, 4))
131-
).toDF("a", "b", "c")
130+
Seq((1, 2, 4, 2), (2, 3, 4, 1))
131+
).toDF("a", "b", "c", "d")
132132
val model = formula.fit(original)
133133
val result = model.transform(original)
134134
val expected = sqlContext.createDataFrame(
135135
Seq(
136-
(1, 2, 4, Vectors.dense(8.0), 1.0),
137-
(2, 3, 4, Vectors.dense(12.0), 2.0))
138-
).toDF("id", "a", "b", "features", "label")
136+
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
137+
(2, 3, 4, 1, Vectors.dense(12.0), 2.0))
138+
).toDF("a", "b", "c", "d", "features", "label")
139139
assert(result.collect() === expected.collect())
140140
val attrs = AttributeGroup.fromStructField(result.schema("features"))
141141
val expectedAttrs = new AttributeGroup(
142142
"features",
143-
Array[Attribute](new NumericAttribute(Some("b:c"), Some(1))))
143+
Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
144144
assert(attrs === expectedAttrs)
145145
}
146146

147-
test("numeric:factor interaction") {
147+
test("factor numeric interaction") {
148148
val formula = new RFormula().setFormula("id ~ a:b")
149149
val original = sqlContext.createDataFrame(
150150
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
@@ -171,7 +171,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
171171
assert(attrs === expectedAttrs)
172172
}
173173

174-
test("factor:factor interaction") {
174+
test("factor factor interaction") {
175175
val formula = new RFormula().setFormula("id ~ a:b")
176176
val original = sqlContext.createDataFrame(
177177
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))

python/pyspark/ml/feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
11171117
11181118
Implements the transforms required for fitting a dataset against an
11191119
R model formula. Currently we support a limited subset of the R
1120-
operators, including '~', '+', '-', and '.'. Also see the R formula
1120+
operators, including '~', '.', ':', '+', and '-'. Also see the R formula
11211121
docs:
11221122
http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
11231123

0 commit comments

Comments
 (0)