Skip to content

Commit dec0634

Browse files
committed
fix locale
1 parent 01471ec commit dec0634

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -94,15 +94,18 @@ class StopWordsRemover(override val uid: String)
9494

9595
@Since("2.0.0")
9696
override def transform(dataset: Dataset[_]): DataFrame = {
97-
val stopWordsSet = if ($(caseSensitive)) {
98-
$(stopWords).toSet
97+
val outputSchema = transformSchema(dataset.schema)
98+
val t = if ($(caseSensitive)) {
99+
val stopWordsSet = $(stopWords).toSet
100+
udf { terms: Seq[String] =>
101+
terms.filterNot(stopWordsSet.contains)
102+
}
99103
} else {
100104
val loadedLocale = StopWordsRemover.loadLocale($(locale))
101-
$(stopWords).filterNot(_ == null).map(_.toLowerCase(loadedLocale)).toSet
102-
}
103-
val outputSchema = transformSchema(dataset.schema)
104-
val t = udf { terms: Seq[String] =>
105-
terms.filterNot(stopWordsSet.contains)
105+
val stopWordsSet = $(stopWords).filterNot(_ == null).map(_.toLowerCase(loadedLocale)).toSet
106+
udf { terms: Seq[String] =>
107+
terms.filterNot(term => stopWordsSet.contains(term.toLowerCase(loadedLocale)))
108+
}
106109
}
107110
val metadata = outputSchema($(outputCol)).metadata
108111
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))

0 commit comments

Comments
 (0)