Skip to content

Commit b01c7c5

Browse files
committed
add test
1 parent 8a637db commit b01c7c5

File tree

3 files changed

+38
-33
lines changed

3 files changed

+38
-33
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
8383
dataset.schema(term) match {
8484
case column if column.dataType == StringType =>
8585
val idxTerm = term + "_idx_" + uid
86-
val indexer = new StringIndexer(uid).setInputCol(term).setOutputCol(idxTerm))
87-
Some(Map(term -> indexer.fit(dataset)))
86+
val indexer = new StringIndexer().setInputCol(term).setOutputCol(idxTerm)
87+
Some(term -> indexer.fit(dataset))
8888
case _ =>
8989
None
9090
}
91-
}
91+
}.toMap
9292
copyValues(new RFormulaModel(uid, parsedFormula.get, factorLevels).setParent(this))
9393
}
9494

@@ -109,6 +109,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
109109

110110
/**
111111
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
112+
* @param parsedFormula a pre-parsed R formula
113+
* @param factorLevels the fitted factor-to-index mappings learned from the training dataset.
112114
*/
113115
private[feature] class RFormulaModel(
114116
override val uid: String,
@@ -136,7 +138,7 @@ private[feature] class RFormulaModel(
136138
}
137139

138140
override def copy(extra: ParamMap): RFormulaModel = copyValues(
139-
new RFormulaModel(uid, parsedFormula))
141+
new RFormulaModel(uid, parsedFormula, factorLevels))
140142

141143
override def toString: String = s"RFormulaModel(${parsedFormula})"
142144

@@ -172,12 +174,13 @@ private[feature] class RFormulaModel(
172174
case column if column.dataType == StringType =>
173175
val encodedTerm = term + "_onehot_" + uid
174176
val indexer = factorLevels(term)
177+
val indexCol = indexer.getOrDefault(indexer.outputCol)
175178
encoderStages :+= indexer
176-
encoderStages :+= new OneHotEncoder(uid)
177-
.setInputCol($(indexer.outputCol))
179+
encoderStages :+= new OneHotEncoder()
180+
.setInputCol(indexCol)
178181
.setOutputCol(encodedTerm)
179182
tempColumns :+= encodedTerm
180-
tempColumns :+= $(indexer.outputCol)
183+
tempColumns :+= indexCol
181184
encodedTerm
182185
case _ =>
183186
term
@@ -186,16 +189,16 @@ private[feature] class RFormulaModel(
186189
encoderStages :+= new VectorAssembler(uid)
187190
.setInputCols(encodedTerms.toArray)
188191
.setOutputCol($(featuresCol))
189-
encoderStages :+= new ColumnPruner(uid, tempColumns.toSet)
192+
encoderStages :+= new ColumnPruner(tempColumns.toSet)
190193
new PipelineModel(uid, encoderStages.toArray)
191194
}
192195
}
193196

194197
/**
195-
* Utility class for removing temporary columns from a DataFrame.
198+
* Utility transformer for removing temporary columns from a DataFrame.
196199
*/
197-
private[ml] class ColumnPruner(
198-
override val uid: String, columnsToPrune: Set[String]) extends Transformer {
200+
private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
201+
override val uid = Identifiable.randomUID("columnPruner")
199202
override def transform(dataset: DataFrame): DataFrame = {
200203
var res: DataFrame = dataset
201204
for (column <- columnsToPrune) {
@@ -212,7 +215,7 @@ private[ml] class ColumnPruner(
212215
/**
213216
* Represents a parsed R formula.
214217
*/
215-
private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
218+
private[ml] case class ParsedRFormula(label: String, terms: Set[String])
216219

217220
/**
218221
* Limited implementation of R formula parsing. Currently supports: '~', '+'.
@@ -223,7 +226,7 @@ private[ml] object RFormulaParser extends RegexParsers {
223226
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
224227

225228
def formula: Parser[ParsedRFormula] =
226-
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
229+
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.toSet) }
227230

228231
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
229232
case Success(result, _) => result

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ package org.apache.spark.ml.feature
2020
import org.apache.spark.SparkFunSuite
2121

2222
class RFormulaParserSuite extends SparkFunSuite {
23-
private def checkParse(formula: String, label: String, terms: Seq[String]) {
23+
private def checkParse(formula: String, label: String, terms: Set[String]) {
2424
val parsed = RFormulaParser.parse(formula)
2525
assert(parsed.label == label)
2626
assert(parsed.terms == terms)
2727
}
2828

2929
test("parse simple formulas") {
30-
checkParse("y ~ x", "y", Seq("x"))
31-
checkParse("y ~ ._foo ", "y", Seq("._foo"))
32-
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
30+
checkParse("y ~ x", "y", Set("x"))
31+
checkParse("y ~ x + x", "y", Set("x"))
32+
checkParse("y ~ ._foo ", "y", Set("._foo"))
33+
checkParse("resp ~ A_VAR + B + c123", "resp", Set("A_VAR", "B", "c123"))
3334
}
3435
}

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
7878
}
7979
}
8080

81-
// TODO(ekl) enable after we implement string label support
82-
// test("transform string label") {
83-
// val formula = new RFormula().setFormula("name ~ id")
84-
// val original = sqlContext.createDataFrame(
85-
// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
86-
// val result = formula.transform(original)
87-
// val resultSchema = formula.transformSchema(original.schema)
88-
// val expected = sqlContext.createDataFrame(
89-
// Seq(
90-
// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
91-
// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
92-
// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
93-
// ).toDF("id", "name", "features", "label")
94-
// assert(result.schema.toString == resultSchema.toString)
95-
// assert(result.collect().toSeq == expected.collect().toSeq)
96-
// }
81+
test("encodes string terms") {
82+
val formula = new RFormula().setFormula("id ~ category")
83+
val original = sqlContext.createDataFrame(
84+
Seq((1, "foo"), (2, "bar"), (3, "bar"), (4, "baz"))).toDF("id", "category")
85+
val model = formula.fit(original)
86+
val result = model.transform(original)
87+
val resultSchema = model.transformSchema(original.schema)
88+
val expected = sqlContext.createDataFrame(
89+
Seq(
90+
(1, "foo", Vectors.dense(Array(0.0, 1.0)), 1.0),
91+
(2, "bar", Vectors.dense(Array(1.0, 0.0)), 2.0),
92+
(3, "bar", Vectors.dense(Array(1.0, 0.0)), 3.0),
93+
(4, "baz", Vectors.dense(Array(0.0, 0.0)), 4.0))
94+
).toDF("id", "name", "features", "label")
95+
assert(result.schema.toString == resultSchema.toString)
96+
assert(result.collect().toSeq == expected.collect().toSeq)
97+
}
9798
}

0 commit comments

Comments
 (0)