@@ -29,7 +29,7 @@ import org.apache.spark.sql.functions.{col, udf}
2929/**
3030 * stop words list
3131 */
32- private object StopWords {
32+ private object StopWords {
3333
3434 /**
3535 * Use the same default stopwords list as scikit-learn.
@@ -80,7 +80,8 @@ private object StopWords{
8080
8181/**
8282 * :: Experimental ::
83- * A feature transformer that filters out stop words from input
83+ * A feature transformer that filters out stop words from input.
84+ * Note: null values in the input array are preserved unless null is explicitly added to stopWords.
8485 * @see [[http://en.wikipedia.org/wiki/Stop_words]]
8586 */
8687@ Experimental
@@ -124,15 +125,19 @@ class StopWordsRemover(override val uid: String)
124125
125126 override def transform (dataset : DataFrame ): DataFrame = {
126127 val outputSchema = transformSchema(dataset.schema)
127- val stopwordsSet = $(stopWords).toSet
128- val lowerStopWords = stopwordsSet.map(_.toLowerCase)
129- val t = udf { terms : Seq [String ] =>
130- if ($(caseSensitive)) {
131- terms.filter(s => s == null || ! stopwordsSet.contains(s))
128+ val t = if ($(caseSensitive)) {
129+ val stopWordsSet = $(stopWords).toSet
130+ udf { terms : Seq [String ] =>
131+ terms.filter(s => ! stopWordsSet.contains(s))
132+ }
132133 } else {
133- terms.filter(s => s == null || ! lowerStopWords.contains(s.toLowerCase))
134- }
134+ val toLower = (s : String ) => if (s != null ) s.toLowerCase else s
135+ val lowerStopWords = $(stopWords).map(toLower(_)).toSet
136+ udf { terms : Seq [String ] =>
137+ terms.filter(s => ! lowerStopWords.contains(toLower(s)))
138+ }
135139 }
140+
136141 val metadata = outputSchema($(outputCol)).metadata
137142 dataset.select(col(" *" ), t(col($(inputCol))).as($(outputCol), metadata))
138143 }
@@ -146,5 +151,5 @@ class StopWordsRemover(override val uid: String)
146151 StructType (outputFields)
147152 }
148153
149- override def copy (extra : ParamMap ): RegexTokenizer = defaultCopy(extra)
154+ override def copy (extra : ParamMap ): StopWordsRemover = defaultCopy(extra)
150155}
0 commit comments