@@ -122,7 +122,7 @@ class StringIndexer @Since("1.4.0") (
122122 /**
 * Sets the `stringOrderType` param to `value` by delegating to `set`, and returns the
 * result typed as `this.type`.
 *
 * NOTE(review): no validation of `value` is visible in this method; presumably the param
 * definition itself restricts it to the supported order types - confirm.
 *
 * @group setParam
 */
123123 @ Since (" 2.2.0" )
124124 def setStringOrderType (value : String ): this .type = set(stringOrderType, value)
125- setDefault(stringOrderType, " freq_desc " )
125+ setDefault(stringOrderType, StringIndexer . FREQ_DESC )
126126
127127 /** @group setParam */
128128 @ Since (" 1.4.0" )
@@ -138,11 +138,11 @@ class StringIndexer @Since("1.4.0") (
138138 val values = dataset.na.drop(Array ($(inputCol)))
139139 .select(col($(inputCol)).cast(StringType ))
140140 .rdd.map(_.getString(0 ))
141- val labels = $(stringOrderType) match {
142- case " freq_desc " => values.countByValue().toSeq.sortBy(- _._2).map(_._1).toArray
143- case " freq_asc " => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144- case " alphabet_desc " => values.distinct.collect.sortWith(_ > _)
145- case " alphabet_asc " => values.distinct.collect.sortWith(_ < _)
141+ val labels = $(stringOrderType).toLowerCase match {
142+ case StringIndexer . FREQ_DESC => values.countByValue().toSeq.sortBy(- _._2).map(_._1).toArray
143+ case StringIndexer . FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144+ case StringIndexer . ALPHABET_DESC => values.distinct.collect.sortWith(_ > _)
145+ case StringIndexer . ALPHABET_ASC => values.distinct.collect.sortWith(_ < _)
146146 }
147147 copyValues(new StringIndexerModel (uid, labels).setParent(this ))
148148 }
@@ -163,8 +163,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
163163 private [feature] val KEEP_INVALID : String = " keep"
164164 private [feature] val supportedHandleInvalids : Array [String ] =
165165 Array (SKIP_INVALID , ERROR_INVALID , KEEP_INVALID )
166+ private [feature] val FREQ_DESC : String = " freq_desc"
167+ private [feature] val FREQ_ASC : String = " freq_asc"
168+ private [feature] val ALPHABET_DESC : String = " alphabet_desc"
169+ private [feature] val ALPHABET_ASC : String = " alphabet_asc"
166170 private [feature] val supportedStringOrderType : Array [String ] =
167- Array (" freq_desc " , " freq_asc " , " alphabet_desc " , " alphabet_asc " )
171+ Array (FREQ_DESC , FREQ_ASC , ALPHABET_DESC , ALPHABET_ASC )
168172
/**
 * Loads a `StringIndexer` from `path` by delegating to the inherited `load`.
 *
 * @param path location passed through unchanged to `super.load`
 * @return the `StringIndexer` produced by `super.load`
 */
169173 @ Since (" 1.6.0" )
170174 override def load (path : String ): StringIndexer = super .load(path)
0 commit comments