@@ -32,21 +32,21 @@ import org.apache.spark.sql.types._
3232 * :: Experimental ::
3333 * Implements the transforms required for fitting a dataset against an R model formula. Currently
3434 * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
35- * docs here: http://www.inside-r.org/r-doc/stats/formula
35+ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
3636 */
3737@ Experimental
38- class RModelFormula (override val uid : String )
38+ class RFormula (override val uid : String )
3939 extends Transformer with HasFeaturesCol with HasLabelCol {
4040
41- def this () = this (Identifiable .randomUID(" rModelFormula " ))
41+ def this () = this (Identifiable .randomUID(" rFormula " ))
4242
4343 /**
4444 * R formula parameter. The formula is provided in string form.
4545 * @group param
4646 */
4747 val formula : Param [String ] = new Param (this , " formula" , " R model formula" )
4848
49- private var parsedFormula : Option [RFormula ] = None
49+ private var parsedFormula : Option [ParsedRFormula ] = None
5050
5151 /**
5252 * Sets the formula to use for this transformer. Must be called before use.
@@ -63,60 +63,74 @@ class RModelFormula(override val uid: String)
6363 def getFormula : String = $(formula)
6464
6565 /** @group setParam */
66- def setFeaturesCol (col : String ): this .type = set(featuresCol, col )
66+ def setFeaturesCol (value : String ): this .type = set(featuresCol, value )
6767
6868 /** @group setParam */
69- def setLabelCol (col : String ): this .type = set(labelCol, col )
69+ def setLabelCol (value : String ): this .type = set(labelCol, value )
7070
7171 override def transformSchema (schema : StructType ): StructType = {
72- require(parsedFormula.isDefined, " Must call setFormula() first." )
73- val withFeatures = featureTransformer.transformSchema(schema)
74- val nullable = schema(parsedFormula.get.response).dataType match {
75- case _ : NumericType | BooleanType => false
76- case _ => true
72+ checkCanTransform(schema)
73+ val withFeatures = transformFeatures.transformSchema(schema)
74+ if (hasLabelCol(schema)) {
75+ withFeatures
76+ } else {
77+ val nullable = schema(parsedFormula.get.label).dataType match {
78+ case _ : NumericType | BooleanType => false
79+ case _ => true
80+ }
81+ StructType (withFeatures.fields :+ StructField ($(labelCol), DoubleType , nullable))
7782 }
78- StructType (withFeatures.fields :+ StructField ($(labelCol), DoubleType , nullable))
7983 }
8084
8185 override def transform (dataset : DataFrame ): DataFrame = {
82- require(parsedFormula.isDefined, " Must call setFormula() first. " )
83- transformLabel(featureTransformer .transform(dataset))
86+ checkCanTransform(dataset.schema )
87+ transformLabel(transformFeatures .transform(dataset))
8488 }
8589
86- override def copy (extra : ParamMap ): RModelFormula = defaultCopy(extra)
90+ override def copy (extra : ParamMap ): RFormula = defaultCopy(extra)
8791
88- override def toString : String = s " RModelFormula ( ${get(formula)}) "
92+ override def toString : String = s " RFormula ( ${get(formula)}) "
8993
9094 private def transformLabel (dataset : DataFrame ): DataFrame = {
91- val responseName = parsedFormula.get.response
92- dataset.schema(responseName).dataType match {
93- case _ : NumericType | BooleanType =>
94- dataset.select(
95- col(" *" ),
96- dataset(responseName).cast(DoubleType ).as($(labelCol)))
97- case StringType =>
98- new StringIndexer ()
99- .setInputCol(responseName)
100- .setOutputCol($(labelCol))
101- .fit(dataset)
102- .transform(dataset)
103- case other =>
104- throw new IllegalArgumentException (" Unsupported type for response: " + other)
95+ if (hasLabelCol(dataset.schema)) {
96+ dataset
97+ } else {
98+ val labelName = parsedFormula.get.label
99+ dataset.schema(labelName).dataType match {
100+ case _ : NumericType | BooleanType =>
101+ dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType ))
102+ // TODO(ekl) add support for string-type labels
103+ case other =>
104+ throw new IllegalArgumentException (" Unsupported type for label: " + other)
105+ }
105106 }
106107 }
107108
108- private def featureTransformer : Transformer = {
109+ private def transformFeatures : Transformer = {
109110 // TODO(ekl) add support for non-numeric features and feature interactions
110111 new VectorAssembler (uid)
111112 .setInputCols(parsedFormula.get.terms.toArray)
112113 .setOutputCol($(featuresCol))
113114 }
115+
116+ private def checkCanTransform (schema : StructType ) {
117+ require(parsedFormula.isDefined, " Must call setFormula() first." )
118+ val columnNames = schema.map(_.name)
119+ require(! columnNames.contains($(featuresCol)), " Features column already exists." )
120+ require(
121+ ! columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType ,
122+ " Label column already exists and is not of type DoubleType." )
123+ }
124+
125+ private def hasLabelCol (schema : StructType ): Boolean = {
126+ schema.map(_.name).contains($(labelCol))
127+ }
114128}
115129
116130/**
117131 * Represents a parsed R formula.
118132 */
119- private [ml] case class RFormula ( response : String , terms : Seq [String ])
133+ private [ml] case class ParsedRFormula ( label : String , terms : Seq [String ])
120134
121135/**
122136 * Limited implementation of R formula parsing. Currently supports: '~', '+'.
@@ -126,9 +140,10 @@ private[ml] object RFormulaParser extends RegexParsers {
126140
127141 def expr : Parser [List [String ]] = term ~ rep(" +" ~> term) ^^ { case a ~ list => a :: list }
128142
129- def formula : Parser [RFormula ] = (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => RFormula (r, t) }
143+ def formula : Parser [ParsedRFormula ] =
144+ (term ~ " ~" ~ expr) ^^ { case r ~ " ~" ~ t => ParsedRFormula (r, t) }
130145
131- def parse (value : String ): RFormula = parseAll(formula, value) match {
146+ def parse (value : String ): ParsedRFormula = parseAll(formula, value) match {
132147 case Success (result, _) => result
133148 case failure : NoSuccess => throw new IllegalArgumentException (
134149 " Could not parse formula: " + value)
0 commit comments