1717
1818package org .apache .spark .ml .feature
1919
20- import scala .collection .mutable
2120import scala .util .parsing .combinator .RegexParsers
2221
2322import org .apache .spark .mllib .linalg .VectorUDT
@@ -32,28 +31,20 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
3231 * of the special '.' term. Duplicate terms will be removed during resolution.
3332 */
3433 def resolve (schema : StructType ): ResolvedRFormula = {
35- val dotTerms = expandDot(schema)
36- var includedTerms = Seq [Seq [String ]]()
34+ var includedTerms = Seq [String ]()
3735 terms.foreach {
38- case term : ColumnRef =>
39- includedTerms :+= Seq (term.value)
40- case ColumnInteraction (terms) =>
41- includedTerms ++= expandInteraction(schema, terms)
4236 case Dot =>
43- includedTerms ++= dotTerms.map(Seq (_))
37+ includedTerms ++= simpleTypes(schema).filter(_ != label.value)
38+ case ColumnRef (value) =>
39+ includedTerms :+= value
4440 case Deletion (term : Term ) =>
4541 term match {
46- case inner : ColumnRef =>
47- includedTerms = includedTerms.filter(_ != Seq (inner.value))
48- case ColumnInteraction (terms) =>
49- val fromInteraction = expandInteraction(schema, terms).map(_.toSet)
50- includedTerms = includedTerms.filter(t => ! fromInteraction.contains(t.toSet))
42+ case ColumnRef (value) =>
43+ includedTerms = includedTerms.filter(_ != value)
5144 case Dot =>
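5245 // e.g. "- .", which is meant to remove all first-order terms. NOTE(review): the added filter below keeps the names found in the schema rather than dropping them (the removed version negated the test) -- confirm.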
53- includedTerms = includedTerms.filter {
54- case Seq (t) => ! dotTerms.contains(t)
55- case _ => true
56- }
46+ val fromSchema = simpleTypes(schema)
47+ includedTerms = includedTerms.filter(fromSchema.contains(_))
5748 case _ : Deletion =>
5849 assert(false , " Deletion terms cannot be nested" )
5950 case _ : Intercept =>
@@ -76,70 +67,31 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
7667 intercept
7768 }
7869
79- // expands the Dot operators in interaction terms
80- private def expandInteraction (
81- schema : StructType , terms : Seq [InteractionComponent ]): Seq [Seq [String ]] = {
82- if (terms.isEmpty) {
83- return Seq (Nil )
84- }
85-
86- val rest = expandInteraction(schema, terms.tail)
87- val validInteractions = (terms.head match {
88- case Dot =>
89- expandDot(schema).filter(_ != label.value).flatMap { t =>
90- rest.map { r =>
91- Seq (t) ++ r
92- }
93- }
94- case ColumnRef (value) =>
95- rest.map(Seq (value) ++ _)
96- }).map(_.distinct)
97-
98- // Deduplicates feature interactions, for example, a:b is the same as b:a.
99- var seen = mutable.Set [Set [String ]]()
100- validInteractions.flatMap {
101- case t if seen.contains(t.toSet) =>
102- None
103- case t =>
104- seen += t.toSet
105- Some (t)
106- }.sortBy(_.length)
107- }
108-
10970 // the dot operator expands only to numeric, string, boolean and vector columns, i.e. it excludes complex column types; the added version returns every such field name and does not drop the label column itself
110- private def expandDot (schema : StructType ): Seq [String ] = {
71+ private def simpleTypes (schema : StructType ): Seq [String ] = {
11172 schema.fields.filter(_.dataType match {
11273 case _ : NumericType | StringType | BooleanType | _ : VectorUDT => true
11374 case _ => false
114- }).map(_.name).filter(_ != label.value)
75+ }).map(_.name)
11576 }
11677}
11778
11879/**
11980 * Represents a fully evaluated and simplified R formula.
120- * @param label the column name of the R formula label (response variable).
121- * @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs
122- * of column names; non-interaction terms as length 1 Seqs.
12381 */
124- private [ml] case class ResolvedRFormula (label : String , terms : Seq [Seq [ String ] ])
82+ private [ml] case class ResolvedRFormula (label : String , terms : Seq [String ])
12583
12684/**
12785 * R formula terms. See the R formula docs here for more information:
12886 * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
12987 */
13088private [ml] sealed trait Term
13189
132- /** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
133- private [ml] sealed trait InteractionComponent extends Term
134-
13590/* R formula reference to all available columns, e.g. "." in a formula */
136- private [ml] case object Dot extends InteractionComponent
91+ private [ml] case object Dot extends Term
13792
13893/* R formula reference to a column, e.g. "+ Species" in a formula */
139- private [ml] case class ColumnRef (value : String ) extends InteractionComponent
140-
141- /* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */
142- private [ml] case class ColumnInteraction (terms : Seq [InteractionComponent ]) extends Term
94+ private [ml] case class ColumnRef (value : String ) extends Term
14395
14496/* R formula intercept toggle, e.g. "+ 0" in a formula */
14597private [ml] case class Intercept (enabled : Boolean ) extends Term
@@ -157,15 +109,7 @@ private[ml] object RFormulaParser extends RegexParsers {
157109 def columnRef : Parser [ColumnRef ] =
158110 " ([a-zA-Z]|\\ .[a-zA-Z_])[a-zA-Z0-9._]*" .r ^^ { case a => ColumnRef (a) }
159111
160- def dot : Parser [InteractionComponent ] = " \\ ." .r ^^ { case _ => Dot }
161-
162- def interaction : Parser [List [InteractionComponent ]] = repsep(columnRef | dot, " :" )
163-
164- def term : Parser [Term ] = intercept |
165- interaction ^^ {
166- case Seq (term) => term
167- case terms => ColumnInteraction (terms)
168- }
112+ def term : Parser [Term ] = intercept | columnRef | " \\ ." .r ^^ { case _ => Dot }
169113
170114 def terms : Parser [List [Term ]] = (term ~ rep(" +" ~ term | " -" ~ term)) ^^ {
171115 case op ~ list => list.foldLeft(List (op)) {
0 commit comments