
Commit fa959d8

separating udf
1 parent f190217 commit fa959d8

File tree: 2 files changed, 16 additions & 11 deletions


mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 15 additions & 10 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.functions.{col, udf}
 /**
  * stop words list
  */
-private object StopWords{
+private object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
@@ -80,7 +80,8 @@ private object StopWords{
 
 /**
  * :: Experimental ::
- * A feature transformer that filters out stop words from input
+ * A feature transformer that filters out stop words from input.
+ * Note: null values from input array are preserved unless adding null to stopWords explicitly.
  * @see [[http://en.wikipedia.org/wiki/Stop_words]]
  */
 @Experimental
@@ -124,15 +125,19 @@ class StopWordsRemover(override val uid: String)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val stopwordsSet = $(stopWords).toSet
-    val lowerStopWords = stopwordsSet.map(_.toLowerCase)
-    val t = udf { terms: Seq[String] =>
-      if ($(caseSensitive)) {
-        terms.filter(s => s == null || !stopwordsSet.contains(s))
+    val t = if ($(caseSensitive)) {
+      val stopWordsSet = $(stopWords).toSet
+      udf { terms: Seq[String] =>
+        terms.filter(s => !stopWordsSet.contains(s))
+      }
     } else {
-        terms.filter(s => s == null || !lowerStopWords.contains(s.toLowerCase))
-      }
+      val toLower = (s: String) => if (s != null) s.toLowerCase else s
+      val lowerStopWords = $(stopWords).map(toLower(_)).toSet
+      udf { terms: Seq[String] =>
+        terms.filter(s => !lowerStopWords.contains(toLower(s)))
+      }
     }
+
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }
@@ -146,5 +151,5 @@ class StopWordsRemover(override val uid: String)
     StructType(outputFields)
   }
 
-  override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
+  override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
 }
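
For context, a minimal usage sketch of the changed transform path (not part of this commit). It assumes a local SQLContext named sqlContext is in scope, as in the MLlib test suites; with caseSensitive = false the lower-casing UDF branch is selected once, and null entries pass through because null is not in the stop word set:

import org.apache.spark.ml.feature.StopWordsRemover

// Hypothetical driver snippet; `sqlContext` is assumed to be available.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setCaseSensitive(false)

val df = sqlContext.createDataFrame(Seq(
  (0, Seq("The", "quick", "brown", "fox", null))
)).toDF("id", "raw")

// "The" is dropped by the case-insensitive branch (the default list contains "the");
// the null element is kept, matching the new scaladoc note above.
remover.transform(df).select("filtered").show()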

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ object StopWordsRemoverSuite extends SparkFunSuite {
 }
 
 class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
-  import org.apache.spark.ml.feature.StopWordsRemoverSuite._
+  import StopWordsRemoverSuite._
 
   test("StopWordsRemover default") {
     val remover = new StopWordsRemover()
