Skip to content

Commit 26b6925

Browse files
committed
Revert user-facing R changes
1 parent 3816477 commit 26b6925

File tree

9 files changed

+63
-271
lines changed

9 files changed

+63
-271
lines changed

R/pkg/R/mllib.R

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
2727
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
2828
#'
2929
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
30-
#' operators are supported, including '~', '.', ':', '+', and '-'.
30+
#' operators are supported, including '~', '+', '-', and '.'.
3131
#' @param data DataFrame for training
3232
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
3333
#' @param lambda Regularization parameter

R/pkg/inst/tests/test_mllib.R

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -49,14 +49,6 @@ test_that("dot minus and intercept vs native glm", {
4949
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
5050
})
5151

52-
test_that("feature interaction vs native glm", {
53-
training <- createDataFrame(sqlContext, iris)
54-
model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
55-
vals <- collect(select(predict(model, training), "prediction"))
56-
rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
57-
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
58-
})
59-
6052
test_that("summary coefficients match with native glm", {
6153
training <- createDataFrame(sqlContext, iris)
6254
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
@@ -65,5 +57,5 @@ test_that("summary coefficients match with native glm", {
6557
expect_true(all(abs(rCoefs - coefs) < 1e-6))
6658
expect_true(all(
6759
as.character(stats$features) ==
68-
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
60+
c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
6961
})

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
4747
/**
4848
* :: Experimental ::
4949
* Implements the transforms required for fitting a dataset against an R model formula. Currently
50-
* we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
51-
* the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
50+
* we support a limited subset of the R operators, including '~', '.', '+', and '-'. Also see
51+
* the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
5252
*/
5353
@Experimental
5454
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
@@ -81,26 +81,32 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
8181
require(isDefined(formula), "Formula must be defined first.")
8282
val parsedFormula = RFormulaParser.parse($(formula))
8383
val resolvedFormula = parsedFormula.resolve(dataset.schema)
84+
// StringType terms and terms representing interactions need to be encoded before assembly.
85+
// TODO(ekl) add support for feature interactions
8486
val encoderStages = ArrayBuffer[PipelineStage]()
8587
val tempColumns = ArrayBuffer[String]()
86-
def encodeInteraction(terms: Seq[String]): String = {
87-
val outputCol = "interaction_" + uid + "_" + terms.mkString(":")
88-
encoderStages += new RInteraction()
89-
.setInputCols(terms.toArray)
90-
.setOutputCol(outputCol)
91-
tempColumns += outputCol
92-
outputCol
93-
}
94-
val encodedTerms = resolvedFormula.terms.map {
95-
case terms @ Seq(value) =>
96-
dataset.schema(value) match {
97-
case column if column.dataType == StringType =>
98-
encodeInteraction(terms)
99-
case _ =>
100-
value
101-
}
102-
case terms =>
103-
encodeInteraction(terms)
88+
val takenNames = mutable.Set(dataset.columns: _*)
89+
val encodedTerms = resolvedFormula.terms.map { term =>
90+
dataset.schema(term) match {
91+
case column if column.dataType == StringType =>
92+
val indexCol = term + "_idx_" + uid
93+
val encodedCol = {
94+
var tmp = term
95+
while (takenNames.contains(tmp)) {
96+
tmp += "_"
97+
}
98+
tmp
99+
}
100+
takenNames.add(indexCol)
101+
takenNames.add(encodedCol)
102+
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
103+
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
104+
tempColumns += indexCol
105+
tempColumns += encodedCol
106+
encodedCol
107+
case _ =>
108+
term
109+
}
104110
}
105111
encoderStages += new VectorAssembler(uid)
106112
.setInputCols(encodedTerms.toArray)
@@ -197,7 +203,7 @@ class RFormulaModel private[feature](
197203
* Utility transformer for removing temporary columns from a DataFrame.
198204
* TODO(ekl) make this a public transformer
199205
*/
200-
private[feature] class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
206+
private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
201207
override val uid = Identifiable.randomUID("columnPruner")
202208

203209
override def transform(dataset: DataFrame): DataFrame = {

mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala

Lines changed: 14 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import scala.collection.mutable
2120
import scala.util.parsing.combinator.RegexParsers
2221

2322
import org.apache.spark.mllib.linalg.VectorUDT
@@ -32,28 +31,20 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
3231
* of the special '.' term. Duplicate terms will be removed during resolution.
3332
*/
3433
def resolve(schema: StructType): ResolvedRFormula = {
35-
val dotTerms = expandDot(schema)
36-
var includedTerms = Seq[Seq[String]]()
34+
var includedTerms = Seq[String]()
3735
terms.foreach {
38-
case term: ColumnRef =>
39-
includedTerms :+= Seq(term.value)
40-
case ColumnInteraction(terms) =>
41-
includedTerms ++= expandInteraction(schema, terms)
4236
case Dot =>
43-
includedTerms ++= dotTerms.map(Seq(_))
37+
includedTerms ++= simpleTypes(schema).filter(_ != label.value)
38+
case ColumnRef(value) =>
39+
includedTerms :+= value
4440
case Deletion(term: Term) =>
4541
term match {
46-
case inner: ColumnRef =>
47-
includedTerms = includedTerms.filter(_ != Seq(inner.value))
48-
case ColumnInteraction(terms) =>
49-
val fromInteraction = expandInteraction(schema, terms).map(_.toSet)
50-
includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet))
42+
case ColumnRef(value) =>
43+
includedTerms = includedTerms.filter(_ != value)
5144
case Dot =>
5245
// e.g. "- .", which removes all first-order terms
53-
includedTerms = includedTerms.filter {
54-
case Seq(t) => !dotTerms.contains(t)
55-
case _ => true
56-
}
46+
val fromSchema = simpleTypes(schema)
47+
includedTerms = includedTerms.filter(fromSchema.contains(_))
5748
case _: Deletion =>
5849
assert(false, "Deletion terms cannot be nested")
5950
case _: Intercept =>
@@ -76,70 +67,31 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
7667
intercept
7768
}
7869

79-
// expands the Dot operators in interaction terms
80-
private def expandInteraction(
81-
schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = {
82-
if (terms.isEmpty) {
83-
return Seq(Nil)
84-
}
85-
86-
val rest = expandInteraction(schema, terms.tail)
87-
val validInteractions = (terms.head match {
88-
case Dot =>
89-
expandDot(schema).filter(_ != label.value).flatMap { t =>
90-
rest.map { r =>
91-
Seq(t) ++ r
92-
}
93-
}
94-
case ColumnRef(value) =>
95-
rest.map(Seq(value) ++ _)
96-
}).map(_.distinct)
97-
98-
// Deduplicates feature interactions, for example, a:b is the same as b:a.
99-
var seen = mutable.Set[Set[String]]()
100-
validInteractions.flatMap {
101-
case t if seen.contains(t.toSet) =>
102-
None
103-
case t =>
104-
seen += t.toSet
105-
Some(t)
106-
}.sortBy(_.length)
107-
}
108-
10970
// the dot operator excludes complex column types
110-
private def expandDot(schema: StructType): Seq[String] = {
71+
private def simpleTypes(schema: StructType): Seq[String] = {
11172
schema.fields.filter(_.dataType match {
11273
case _: NumericType | StringType | BooleanType | _: VectorUDT => true
11374
case _ => false
114-
}).map(_.name).filter(_ != label.value)
75+
}).map(_.name)
11576
}
11677
}
11778

11879
/**
11980
* Represents a fully evaluated and simplified R formula.
120-
* @param label the column name of the R formula label (response variable).
121-
* @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs
122-
* of column names; non-interaction terms as length 1 Seqs.
12381
*/
124-
private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]])
82+
private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
12583

12684
/**
12785
* R formula terms. See the R formula docs here for more information:
12886
* http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
12987
*/
13088
private[ml] sealed trait Term
13189

132-
/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
133-
private[ml] sealed trait InteractionComponent extends Term
134-
13590
/* R formula reference to all available columns, e.g. "." in a formula */
136-
private[ml] case object Dot extends InteractionComponent
91+
private[ml] case object Dot extends Term
13792

13893
/* R formula reference to a column, e.g. "+ Species" in a formula */
139-
private[ml] case class ColumnRef(value: String) extends InteractionComponent
140-
141-
/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */
142-
private[ml] case class ColumnInteraction(terms: Seq[InteractionComponent]) extends Term
94+
private[ml] case class ColumnRef(value: String) extends Term
14395

14496
/* R formula intercept toggle, e.g. "+ 0" in a formula */
14597
private[ml] case class Intercept(enabled: Boolean) extends Term
@@ -157,15 +109,7 @@ private[ml] object RFormulaParser extends RegexParsers {
157109
def columnRef: Parser[ColumnRef] =
158110
"([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
159111

160-
def dot: Parser[InteractionComponent] = "\\.".r ^^ { case _ => Dot }
161-
162-
def interaction: Parser[List[InteractionComponent]] = repsep(columnRef | dot, ":")
163-
164-
def term: Parser[Term] = intercept |
165-
interaction ^^ {
166-
case Seq(term) => term
167-
case terms => ColumnInteraction(terms)
168-
}
112+
def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
169113

170114
def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
171115
case op ~ list => list.foldLeft(List(op)) {

mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -127,8 +127,8 @@ class RInteraction(override val uid: String) extends Estimator[PipelineModel]
127127
}
128128

129129
/**
130-
* This helper class computes the joint index of multiple string-indexed columns such that the
131-
* combined index covers the cartesian product of column values.
130+
* Computes the joint index of multiple string-indexed columns such that the combined index
131+
* covers the cartesian product of column values.
132132
*/
133133
private class IndexCombiner(
134134
inputCols: Array[String], attrNames: Array[String], outputCol: String)
@@ -181,8 +181,8 @@ private class IndexCombiner(
181181
}
182182

183183
/**
184-
* This helper class scales the input vector column by the product of the input numeric columns.
185-
* If no vector column is specified, the output is just the product of the numeric columns.
184+
* Scales the input vector column by the product of the input numeric columns. If no vector column
185+
* is specified, the output is just the product of the numeric columns.
186186
*/
187187
private class NumericInteraction(
188188
inputCols: Array[String], vectorCol: Option[String], outputCol: String)

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,12 +56,10 @@ class VectorAssembler(override val uid: String)
5656
val index = schema.fieldIndex(c)
5757
field.dataType match {
5858
case DoubleType =>
59-
val attr = Attribute.decodeStructField(field, preserveName = true)
59+
val attr = Attribute.fromStructField(field)
6060
// If the input column doesn't have ML attribute, assume numeric.
6161
if (attr == UnresolvedAttribute) {
6262
Some(NumericAttribute.defaultAttr.withName(c))
63-
} else if (attr.name.isDefined) {
64-
Some(attr)
6563
} else {
6664
Some(attr.withName(c))
6765
}
@@ -71,8 +69,15 @@ class VectorAssembler(override val uid: String)
7169
case _: VectorUDT =>
7270
val group = AttributeGroup.fromStructField(field)
7371
if (group.attributes.isDefined) {
74-
// If attributes are defined, copy them.
75-
group.attributes.get
72+
// If attributes are defined, copy them with updated names.
73+
group.attributes.get.map { attr =>
74+
if (attr.name.isDefined) {
75+
// TODO: Define a rigorous naming scheme.
76+
attr.withName(c + "_" + attr.name.get)
77+
} else {
78+
attr
79+
}
80+
}
7681
} else {
7782
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
7883
// from metadata, check the first row.

0 commit comments

Comments
 (0)