1717
1818package org .apache .spark .ml .feature
1919
20+ import scala .collection .mutable .ArrayBuffer
2021import scala .util .parsing .combinator .RegexParsers
2122
2223import org .apache .spark .annotation .Experimental
23- import org .apache .spark .ml .Transformer
24+ import org .apache .spark .ml .{ Estimator , Model , Transformer , Pipeline , PipelineModel , PipelineStage }
2425import org .apache .spark .ml .param .{Param , ParamMap }
2526import org .apache .spark .ml .param .shared .{HasFeaturesCol , HasLabelCol }
2627import org .apache .spark .ml .util .Identifiable
28+ import org .apache .spark .mllib .linalg .VectorUDT
2729import org .apache .spark .sql .DataFrame
2830import org .apache .spark .sql .functions ._
2931import org .apache .spark .sql .types ._
3032
33+ /**
34+ * Base trait for [[RFormula ]] and [[RFormulaModel ]].
35+ */
private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
   * Checks whether the given schema already has a column named by the `labelCol` param.
   *
   * @param schema schema to inspect
   * @return true if a field with the configured label column name is present
   */
  protected def hasLabelCol(schema: StructType): Boolean = {
    schema.fieldNames.contains($(labelCol))
  }
}
47+
3148/**
3249 * :: Experimental ::
3350 * Implements the transforms required for fitting a dataset against an R model formula. Currently
3451 * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
3552 * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
3653 */
3754@ Experimental
38- class RFormula (override val uid : String )
39- extends Transformer with HasFeaturesCol with HasLabelCol {
55+ class RFormula (override val uid : String ) extends Estimator [RFormulaModel ] with RFormulaBase {
4056
4157 def this () = this (Identifiable .randomUID(" rFormula" ))
4258
@@ -62,19 +78,74 @@ class RFormula(override val uid: String)
6278 /** @group getParam */
6379 def getFormula : String = $(formula)
6480
65- /** @group getParam */
66- def setFeaturesCol (value : String ): this .type = set(featuresCol, value)
/**
 * Fits the feature pipeline implied by the parsed formula against `dataset`.
 *
 * String-typed terms are indexed and one-hot encoded into temporary columns; all terms are then
 * assembled into the features column and the temporary columns are pruned again.
 *
 * @param dataset input data whose schema must contain every term of the formula
 * @return a model wrapping the parsed formula and the fitted pipeline
 * @throws IllegalArgumentException if no formula has been set
 */
override def fit(dataset: DataFrame): RFormulaModel = {
  require(parsedFormula.isDefined, "Must call setFormula() first.")
  val parsed = parsedFormula.get
  // StringType terms and terms representing interactions need to be encoded before assembly.
  // TODO(ekl) add support for feature interactions
  // The buffers are mutated in place, never reassigned, so they are declared as vals.
  val encoderStages = ArrayBuffer[PipelineStage]()
  val tempColumns = ArrayBuffer[String]()
  val encodedTerms = parsed.terms.map { term =>
    dataset.schema(term).dataType match {
      case StringType =>
        // Suffix with the uid so the temporary columns cannot clash with user columns.
        val indexCol = term + "_idx_" + uid
        val encodedCol = term + "_onehot_" + uid
        encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
        encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
        tempColumns += indexCol
        tempColumns += encodedCol
        encodedCol
      case _ =>
        term
    }
  }
  encoderStages += new VectorAssembler(uid)
    .setInputCols(encodedTerms.toArray)
    .setOutputCol($(featuresCol))
  encoderStages += new ColumnPruner(tempColumns.toSet)
  val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
  copyValues(new RFormulaModel(uid, parsed, pipelineModel).setParent(this))
}
67108
68- /** @group getParam */
69- def setLabelCol (value : String ): this .type = set(labelCol, value)
/**
 * Derives the output schema without fitting: the features column is always appended, and a
 * nullable double label column is appended as well when the input has no label column yet.
 * This is an optimistic schema; it does not contain any ML attributes.
 */
override def transformSchema(schema: StructType): StructType = {
  val featuresField = StructField($(featuresCol), new VectorUDT, true)
  val appended =
    if (hasLabelCol(schema)) {
      Array(featuresField)
    } else {
      Array(featuresField, StructField($(labelCol), DoubleType, true))
    }
  StructType(schema.fields ++ appended)
}
118+
119+ override def copy (extra : ParamMap ): RFormula = defaultCopy(extra)
120+
121+ override def toString : String = s " RFormula( ${get(formula)}) "
122+ }
123+
124+ /**
125+ * :: Experimental ::
126+ * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
127+ * @param parsedFormula a pre-parsed R formula.
128+ * @param pipelineModel the fitted feature model, including factor to index mappings.
129+ */
130+ @ Experimental
131+ class RFormulaModel private [feature](
132+ override val uid : String ,
133+ parsedFormula : ParsedRFormula ,
134+ pipelineModel : PipelineModel )
135+ extends Model [RFormulaModel ] with RFormulaBase {
136+
/**
 * Validates the input schema, runs the fitted feature pipeline and then derives the label
 * column from the result.
 */
override def transform(dataset: DataFrame): DataFrame = {
  checkCanTransform(dataset.schema)
  val withFeatures = pipelineModel.transform(dataset)
  transformLabel(withFeatures)
}
70141
71142 override def transformSchema (schema : StructType ): StructType = {
72143 checkCanTransform(schema)
73- val withFeatures = transformFeatures .transformSchema(schema)
144+ val withFeatures = pipelineModel .transformSchema(schema)
74145 if (hasLabelCol(schema)) {
75146 withFeatures
76- } else if (schema.exists(_.name == parsedFormula.get. label)) {
77- val nullable = schema(parsedFormula.get. label).dataType match {
147+ } else if (schema.exists(_.name == parsedFormula.label)) {
148+ val nullable = schema(parsedFormula.label).dataType match {
78149 case _ : NumericType | BooleanType => false
79150 case _ => true
80151 }
@@ -86,24 +157,19 @@ class RFormula(override val uid: String)
86157 }
87158 }
88159
89- override def transform (dataset : DataFrame ): DataFrame = {
90- checkCanTransform(dataset.schema)
91- transformLabel(transformFeatures.transform(dataset))
92- }
93-
94- override def copy (extra : ParamMap ): RFormula = defaultCopy(extra)
/**
 * Creates a copy of this model that shares the parsed formula and the fitted pipeline.
 *
 * @param extra param overrides applied on top of this model's current param values
 * @return the copied model, with the same uid and parent estimator as this one
 */
override def copy(extra: ParamMap): RFormulaModel = {
  // Carry the parent over and honour `extra`; otherwise overrides passed to copy() are dropped.
  val copied = new RFormulaModel(uid, parsedFormula, pipelineModel).setParent(parent)
  copyValues(copied, extra)
}
95162
96- override def toString : String = s " RFormula ( ${get(formula) }) "
163+ override def toString : String = s " RFormulaModel ( ${parsedFormula }) "
97164
98165 private def transformLabel (dataset : DataFrame ): DataFrame = {
99- val labelName = parsedFormula.get. label
166+ val labelName = parsedFormula.label
100167 if (hasLabelCol(dataset.schema)) {
101168 dataset
102169 } else if (dataset.schema.exists(_.name == labelName)) {
103170 dataset.schema(labelName).dataType match {
104171 case _ : NumericType | BooleanType =>
105172 dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType ))
106- // TODO(ekl) add support for string-type labels
107173 case other =>
108174 throw new IllegalArgumentException (" Unsupported type for label: " + other)
109175 }
@@ -114,25 +180,32 @@ class RFormula(override val uid: String)
114180 }
115181 }
116182
117- private def transformFeatures : Transformer = {
118- // TODO(ekl) add support for non-numeric features and feature interactions
119- new VectorAssembler (uid)
120- .setInputCols(parsedFormula.get.terms.toArray)
121- .setOutputCol($(featuresCol))
122- }
123-
/**
 * Verifies that the model can be applied to a dataset with the given schema.
 *
 * @param schema input schema to validate
 * @throws IllegalArgumentException if the features column already exists, or if the label
 *                                  column exists with a type other than DoubleType
 */
private def checkCanTransform(schema: StructType): Unit = {
  val columnNames = schema.map(_.name)
  require(!columnNames.contains($(featuresCol)), "Features column already exists.")
  // An existing label column is accepted only when it is already a double.
  require(
    !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
    "Label column already exists and is not of type DoubleType.")
}
190+ }
132191
133- private def hasLabelCol (schema : StructType ): Boolean = {
134- schema.map(_.name).contains($(labelCol))
192+ /**
193+ * Utility transformer for removing temporary columns from a DataFrame.
194+ * TODO(ekl) make this a public transformer
195+ */
private class ColumnPruner(override val uid: String, columnsToPrune: Set[String])
  extends Transformer {

  /** Creates a pruner with a freshly generated uid. */
  def this(columnsToPrune: Set[String]) =
    this(Identifiable.randomUID("columnPruner"), columnsToPrune)

  /** Returns `dataset` with every column named in `columnsToPrune` dropped. */
  override def transform(dataset: DataFrame): DataFrame = {
    val columnsToKeep = dataset.columns.filterNot(columnsToPrune.contains)
    dataset.select(columnsToKeep.map(dataset.col): _*)
  }

  /** Returns `schema` without the fields named in `columnsToPrune`. */
  override def transformSchema(schema: StructType): StructType = {
    StructType(schema.fields.filterNot(field => columnsToPrune.contains(field.name)))
  }

  // defaultCopy reflectively looks up a single-String (uid) constructor, which this class does
  // not have, so the copy is built explicitly to keep the uid and the pruned column set.
  override def copy(extra: ParamMap): ColumnPruner =
    copyValues(new ColumnPruner(uid, columnsToPrune), extra)
}
137210
138211/**
@@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers {
149222 def expr : Parser [List [String ]] = term ~ rep(" +" ~> term) ^^ { case a ~ list => a :: list }
150223
151224 def formula : Parser [ParsedRFormula ] =
152- (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => ParsedRFormula (r, t) }
225+ (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => ParsedRFormula (r, t.distinct ) }
153226
154227 def parse (value : String ): ParsedRFormula = parseAll(formula, value) match {
155228 case Success (result, _) => result
0 commit comments