Skip to content

Commit ba34043

Browse files
author
Wayne Zhang
committed
address comments- spell out freq and update annotation and toLowerCase
1 parent 97e020f commit ba34043

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import scala.language.existentials
20+
import java.util.Locale
2121

22+
import scala.language.existentials
2223
import org.apache.hadoop.fs.Path
2324

2425
import org.apache.spark.SparkException
@@ -63,23 +64,25 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
6364
* Param for how to order labels of string column. The first label after ordering is assigned
6465
* an index of 0.
6566
* Options are:
66-
* - 'freq_desc': descending order by label frequency (most frequent label assigned 0)
67-
* - 'freq_asc': ascending order by label frequency (least frequent label assigned 0)
67+
* - 'frequency_desc': descending order by label frequency (most frequent label assigned 0)
68+
* - 'frequency_asc': ascending order by label frequency (least frequent label assigned 0)
6869
* - 'alphabet_desc': descending alphabetical order
6970
* - 'alphabet_asc': ascending alphabetical order
70-
* Default is 'freq_desc'.
71+
* Default is 'frequency_desc'.
7172
*
7273
* @group param
7374
*/
74-
@Since("2.2.0")
75+
@Since("2.3.0")
7576
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
76-
"The method used to order values of input column. " +
77-
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
78-
(value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase))
77+
"how to order labels of string column. " +
78+
"The first label after ordering is assigned an index of 0. " +
79+
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
80+
(value: String) => StringIndexer.supportedStringOrderType
81+
.contains(value.toLowerCase(Locale.ROOT)))
7982

8083
/** @group getParam */
81-
@Since("2.2.0")
82-
def getStringOrderType: String = $(stringOrderType)
84+
@Since("2.3.0")
85+
def getStringOrderType: String = $(stringOrderType).toLowerCase(Locale.ROOT)
8386

8487
/** Validates and transforms the input schema. */
8588
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -138,7 +141,7 @@ class StringIndexer @Since("1.4.0") (
138141
val values = dataset.na.drop(Array($(inputCol)))
139142
.select(col($(inputCol)).cast(StringType))
140143
.rdd.map(_.getString(0))
141-
val labels = $(stringOrderType).toLowerCase match {
144+
val labels = this.getStringOrderType match {
142145
case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray
143146
case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144147
case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _)
@@ -163,8 +166,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
163166
private[feature] val KEEP_INVALID: String = "keep"
164167
private[feature] val supportedHandleInvalids: Array[String] =
165168
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
166-
private[feature] val FREQ_DESC: String = "freq_desc"
167-
private[feature] val FREQ_ASC: String = "freq_asc"
169+
private[feature] val FREQ_DESC: String = "frequency_desc"
170+
private[feature] val FREQ_ASC: String = "frequency_asc"
168171
private[feature] val ALPHABET_DESC: String = "alphabet_desc"
169172
private[feature] val ALPHABET_ASC: String = "alphabet_asc"
170173
private[feature] val supportedStringOrderType: Array[String] =

0 commit comments

Comments
 (0)