-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-14659][ML] RFormula consistent with R when handling strings #17967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d27123
6841c33
77fe864
a1be94c
698588e
147311b
5f31d31
341949c
24818a7
1a1e06c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since} | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.attribute.AttributeGroup | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.linalg.VectorUDT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.ml.util._ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -37,6 +37,42 @@ import org.apache.spark.sql.types._ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Param for how to order categories of a string FEATURE column used by `StringIndexer`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * The last category after ordering is dropped when encoding strings. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', `RFormula` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * drops the same category as R when encoding strings. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The order should be |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * The options are explained using an example `'b', 'a', 'b', 'a', 'c', 'b'`: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * {{{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * +-----------------+---------------------------------------+----------------------------------+ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to suggest just to write out as prose with a simple list if we are all fine for now, which I guess we would generally agree with.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @HyukjinKwon Would you please clarify what you mean by a
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, sure, I initially meant a HTML list that we are already using - spark/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala Lines 304 to 340 in 04901dd
<ul>
<li> abc </li>
<li> abc </li>
</ul>I just tested it to double-check a wiki-style list ( Scaladoc Javadoc My worry is, it draws attention with a different format. I believe we have similar instances but wonder if it is worth changing only this one. I would not strongly against but
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I am not supposed to make a decision call though. Please let me know @felixcheung and @yanboliang if you have any preference.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @HyukjinKwon Thanks for the clarification. I don't think
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. according to this, table is https://wiki.scala-lang.org/display/SW/Syntax
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @felixcheung @HyukjinKwon The scaladoc complied, but the javadoc failed... Not sure if there is additional config for java?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think javadoc 8 complains about HTML. It looks this works: * <table summary="abc">
* <tr>
* <th>Firstname</th>
* <th>Lastname</th>
* <th>Age</th>
* </tr>
* <tr>
* <td>Jill</td>
* <td>Smith</td>
* <td>50</td>
* </tr>
* <tr>
* <td>Eve</td>
* <td>Jackson</td>
* <td>94</td>
* </tr>
* </table>Scaladoc Javadoc Other errors probably are spurious (please refer https://issues.apache.org/jira/browse/SPARK-20840 which I am fighting with right now).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @HyukjinKwon Nice. Thanks much. One issue is in the scaladoc, the columns are very close to each other. How to add spacing between columns in the scaladoc?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure. it did not work for me too ... |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | Option | Category mapped to 0 by StringIndexer | Category dropped by RFormula | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * +-----------------+---------------------------------------+----------------------------------+ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | 'frequencyDesc' | most frequent category ('b') | least frequent category ('c') | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | 'frequencyAsc' | least frequent category ('c') | most frequent category ('b') | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | 'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a')| | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | 'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * +-----------------+---------------------------------------+----------------------------------+ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * }}} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Note that this ordering option is NOT used for the label column. When the label column is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * indexed, it uses the default descending frequency ordering in `StringIndexer`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * @group param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @Since("2.3.0") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "How to order categories of a string FEATURE column used by StringIndexer. " + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "The last category after ordering is dropped when encoding strings. " + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "RFormula drops the same category as R when encoding strings.", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ParamValidators.inArray(StringIndexer.supportedStringOrderType)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /** @group getParam */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @Since("2.3.0") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def getStringIndexerOrderType: String = $(stringIndexerOrderType) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| protected def hasLabelCol(schema: StructType): Boolean = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schema.map(_.name).contains($(labelCol)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -125,6 +161,11 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @Since("2.1.0") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /** @group setParam */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @Since("2.3.0") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /** Whether the formula specifies fitting an intercept. */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private[ml] def hasIntercept: Boolean = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| require(isDefined(formula), "Formula must be defined first.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -155,6 +196,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| encoderStages += new StringIndexer() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .setInputCol(term) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .setOutputCol(indexCol) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .setStringOrderType($(stringIndexerOrderType)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prefixesToRewrite(indexCol + "_") = term + "_" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (term, indexCol) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| case _ => | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| assert(result.collect() === expected.collect()) | ||
| } | ||
|
|
||
| test("encodes string terms with string indexer order type") { | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) | ||
| .toDF("id", "a", "b") | ||
|
|
||
| val expected = Seq( | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 1.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 0.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 0.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label") | ||
| ) | ||
|
|
||
| var idx = 0 | ||
| for (orderType <- StringIndexer.supportedStringOrderType) { | ||
| val model = formula.setStringIndexerOrderType(orderType).fit(original) | ||
| val result = model.transform(original) | ||
| val resultSchema = model.transformSchema(original.schema) | ||
| assert(result.schema.toString == resultSchema.toString) | ||
| assert(result.collect() === expected(idx).collect()) | ||
| idx += 1 | ||
| } | ||
| } | ||
|
|
||
| test("test consistency with R when encoding string terms") { | ||
| /* | ||
| R code: | ||
|
|
||
| df <- data.frame(id = c(1, 2, 3, 4), | ||
| a = c("foo", "bar", "bar", "aaz"), | ||
| b = c(4, 4, 5, 5)) | ||
| model.matrix(id ~ a + b, df)[, -1] | ||
|
|
||
| abar afoo b | ||
| 0 1 4 | ||
| 1 0 4 | ||
| 1 0 5 | ||
| 0 0 5 | ||
| */ | ||
| val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) | ||
| .toDF("id", "a", "b") | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| .setStringIndexerOrderType(StringIndexer.alphabetDesc) | ||
|
|
||
| /* | ||
| Note that the category dropped after encoding is the same between R and Spark | ||
| (i.e., "aaz" is treated as the reference level). | ||
| However, the column order is still different: | ||
| R renders the columns in ascending alphabetical order ("bar", "foo"), while | ||
| RFormula renders the columns in descending alphabetical order ("foo", "bar"). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. R and RFormula should behavior consistent if you fix the issue I mentioned above. |
||
| */ | ||
| val expected = Seq( | ||
| (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label") | ||
|
|
||
| val model = formula.fit(original) | ||
| val result = model.transform(original) | ||
| val resultSchema = model.transformSchema(original.schema) | ||
| assert(result.schema.toString == resultSchema.toString) | ||
| assert(result.collect() === expected.collect()) | ||
| } | ||
|
|
||
| test("index string label") { | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| val original = | ||
|
|
||





There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? Do you have some references? AFAIK, R formula drop the first category alphabetically ascending order when encode string/category feature (which is consistent with your PR description). I think
test("StringIndexer order types")in #17879 is correct. Could you double check this?