@@ -83,12 +83,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
8383 dataset.schema(term) match {
8484 case column if column.dataType == StringType =>
8585 val idxTerm = term + " _idx_" + uid
86- val indexer = new StringIndexer (uid ).setInputCol(term).setOutputCol(idxTerm) )
87- Some (Map ( term -> indexer.fit(dataset) ))
86+ val indexer = new StringIndexer ().setInputCol(term).setOutputCol(idxTerm)
87+ Some (term -> indexer.fit(dataset))
8888 case _ =>
8989 None
9090 }
91- }
91+ }.toMap
9292 copyValues(new RFormulaModel (uid, parsedFormula.get, factorLevels).setParent(this ))
9393 }
9494
@@ -109,6 +109,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
109109
110110/**
111111 * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
112+ * @param parsedFormula a pre-parsed R formula
113+ * @param factorLevels the fitted factor to index mappings from the training dataset.
112114 */
113115private [feature] class RFormulaModel (
114116 override val uid : String ,
@@ -136,7 +138,7 @@ private[feature] class RFormulaModel(
136138 }
137139
138140 override def copy (extra : ParamMap ): RFormulaModel = copyValues(
139- new RFormulaModel (uid, parsedFormula))
141+ new RFormulaModel (uid, parsedFormula, factorLevels ))
140142
141143 override def toString : String = s " RFormulaModel( ${parsedFormula}) "
142144
@@ -172,12 +174,13 @@ private[feature] class RFormulaModel(
172174 case column if column.dataType == StringType =>
173175 val encodedTerm = term + " _onehot_" + uid
174176 val indexer = factorLevels(term)
177+ val indexCol = indexer.getOrDefault(indexer.outputCol)
175178 encoderStages :+= indexer
176- encoderStages :+= new OneHotEncoder (uid )
177- .setInputCol($(indexer.outputCol) )
179+ encoderStages :+= new OneHotEncoder ()
180+ .setInputCol(indexCol )
178181 .setOutputCol(encodedTerm)
179182 tempColumns :+= encodedTerm
180- tempColumns :+= $(indexer.outputCol)
183+ tempColumns :+= indexCol
181184 encodedTerm
182185 case _ =>
183186 term
@@ -186,16 +189,16 @@ private[feature] class RFormulaModel(
186189 encoderStages :+= new VectorAssembler (uid)
187190 .setInputCols(encodedTerms.toArray)
188191 .setOutputCol($(featuresCol))
189- encoderStages :+= new ColumnPruner (uid, tempColumns.toSet)
192+ encoderStages :+= new ColumnPruner (tempColumns.toSet)
190193 new PipelineModel (uid, encoderStages.toArray)
191194 }
192195}
193196
194197/**
195- * Utility class for removing temporary columns from a DataFrame.
198+ * Utility transformer for removing temporary columns from a DataFrame.
196199 */
197- private [ml] class ColumnPruner (
198- override val uid : String , columnsToPrune : Set [ String ]) extends Transformer {
200+ private class ColumnPruner (columnsToPrune : Set [ String ]) extends Transformer {
201+ override val uid = Identifiable .randomUID( " columnPruner " )
199202 override def transform (dataset : DataFrame ): DataFrame = {
200203 var res : DataFrame = dataset
201204 for (column <- columnsToPrune) {
@@ -212,7 +215,7 @@ private[ml] class ColumnPruner(
212215/**
213216 * Represents a parsed R formula.
214217 */
215- private [ml] case class ParsedRFormula (label : String , terms : Seq [String ])
218+ private [ml] case class ParsedRFormula (label : String , terms : Set [String ])
216219
217220/**
218221 * Limited implementation of R formula parsing. Currently supports: '~', '+'.
@@ -223,7 +226,7 @@ private[ml] object RFormulaParser extends RegexParsers {
223226 def expr : Parser [List [String ]] = term ~ rep(" +" ~> term) ^^ { case a ~ list => a :: list }
224227
225228 def formula : Parser [ParsedRFormula ] =
226- (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => ParsedRFormula (r, t) }
229+ (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => ParsedRFormula (r, t.toSet ) }
227230
228231 def parse (value : String ): ParsedRFormula = parseAll(formula, value) match {
229232 case Success (result, _) => result
0 commit comments