Skip to content

Commit 41cd258

Browse files
committed
update stopwordsremover
1 parent 6deceec commit 41cd258

File tree

2 files changed

+48
-61
lines changed

2 files changed

+48
-61
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 26 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,14 @@ private[spark] object StopWords {
3333

3434
/**
 * Read the stop words list for `language` from the bundled resources.
 *
 * @param language language name; must be one of `supportedLanguages` (lowercase)
 * @return one array element per line of the resource file
 * @throws IllegalArgumentException if `language` is not supported or its resource is missing
 */
def readStopWords(language: String): Array[String] = {
  require(supportedLanguages.contains(language), s"$language is not in language list")
  val is = getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
  // getResourceAsStream returns null for a missing resource; fail with a clear message
  // instead of a NullPointerException from inside Source.
  require(is != null, s"stop words resource for $language is missing")
  try {
    // Decode explicitly as UTF-8 rather than with the platform default charset, so that
    // non-ASCII stop words read the same on every JVM.
    // NOTE(review): assumes the bundled lists are stored as UTF-8 - confirm.
    // toArray materializes every line before the stream is closed below.
    scala.io.Source.fromInputStream(is)(scala.io.Codec.UTF8).getLines().toArray
  } finally {
    is.close()
  }
}
3940

4041
/** Supported languages list must be lowercase */
41-
val supportedLanguages = Array("danish", "dutch", "english", "finnish", "french", "german",
42+
private val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
4243
"hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")
43-
44-
/** Languages and stopwords map */
45-
val languageMap = supportedLanguages.map{
46-
language => language -> readStopWords(language)
47-
}.toMap
4844
}
4945

5046
/**
@@ -67,16 +63,13 @@ class StopWordsRemover(override val uid: String)
6763

6864
/**
6965
* the stop words set to be filtered out
70-
* Default: [[StopWords.languageMap("english")]]
66+
* Default: [[Array.empty]], in which case the stop words of `language` are loaded at transform time
7167
* @group param
7268
*/
7369
val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words")
7470

7571
/** @group setParam */
76-
def setStopWords(value: Array[String]): this.type = {
77-
set(stopWords, value)
78-
set(language, "unknown")
79-
}
72+
def setStopWords(value: Array[String]): this.type = set(stopWords, value)
8073

8174
/** @group getParam */
8275
def getStopWords: Array[String] = $(stopWords)
@@ -96,70 +89,39 @@ class StopWordsRemover(override val uid: String)
9689
def getCaseSensitive: Boolean = $(caseSensitive)
9790

9891
/**
99-
* the language of stop words
100-
* Default: "english"
101-
* @group param
102-
*/
92+
* the language of stop words
93+
* Supported languages: Danish, Dutch, English, Finnish, French, German, Hungarian,
94+
* Italian, Norwegian, Portuguese, Russian, Spanish, Swedish, Turkish
95+
* Default: "english" (values passed to setLanguage are lowercased)
96+
* @group param
97+
*/
10398
val language: Param[String] = new Param[String](this, "language", "stopwords language")
10499

105100
/** @group setParam */
106-
def setLanguage(value: String): this.type = {
107-
val lang = value.toLowerCase
108-
require(StopWords.languageMap.contains(lang), s"$lang is not in language list")
109-
set(language, lang)
110-
set(stopWords, StopWords.languageMap(lang))
111-
}
101+
def setLanguage(value: String): this.type = {
  // Lowercase with Locale.ROOT so the stored name does not depend on the JVM default
  // locale (under a Turkish locale "ITALIAN" would otherwise become "ıtalıan").
  set(language, value.toLowerCase(java.util.Locale.ROOT))
}
112102

113103
/** @group getParam */
114104
def getLanguage: String = $(language)
115105

116-
/**
117-
* the ignored stop words set to be ignored out
118-
* Default: [[Array.empty]]
119-
* @group param
120-
*/
121-
val ignoredWords: StringArrayParam = new StringArrayParam(this, "ignoredWords",
122-
"the ignored stop words set to be ignored out")
123-
124-
/** @group setParam */
125-
def setIgnoredWords(value: Array[String]): this.type = set(ignoredWords, value)
126-
127-
/** @group getParam */
128-
def getIgnoredWords: Array[String] = $(ignoredWords)
129-
130-
/**
131-
* the additional stop words set to be filtered out
132-
* Default: [[Array.empty]]
133-
* @group param
134-
*/
135-
val additionalWords: StringArrayParam = new StringArrayParam(this, "additionalWords",
136-
"the additional stop words set to be filtered out")
137-
138-
/** @group setParam */
139-
def setAdditionalWords(value: Array[String]): this.type = set(additionalWords, value)
140-
141-
/** @group getParam */
142-
def getAdditionalWords: Array[String] = $(additionalWords)
143-
144-
setDefault(stopWords -> StopWords.languageMap("english"),
106+
setDefault(stopWords -> Array.empty[String],
145107
language -> "english",
146-
ignoredWords -> Array.empty[String],
147-
additionalWords -> Array.empty[String],
148108
caseSensitive -> false)
149109

150110
override def transform(dataset: DataFrame): DataFrame = {
111+
val stopWordsSet = if ($(stopWords).isEmpty) {
112+
StopWords.readStopWords($(language)).toSet
113+
} else {
114+
$(stopWords).toSet
115+
}
116+
151117
val outputSchema = transformSchema(dataset.schema)
152118
val t = if ($(caseSensitive)) {
153-
val stopWordsSet = ($(stopWords) ++ $(additionalWords)).toSet -- $(ignoredWords).toSet
154119
udf { terms: Seq[String] =>
155120
terms.filter(s => !stopWordsSet.contains(s))
156121
}
157122
} else {
158123
val toLower = (s: String) => if (s != null) s.toLowerCase else s
159-
val lowerStopWords = {
160-
($(stopWords) ++ $(additionalWords))
161-
.map(toLower(_)).toSet -- $(ignoredWords).map(toLower(_)).toSet
162-
}
124+
val lowerStopWords = stopWordsSet.map(toLower(_)).toSet
163125
udf { terms: Seq[String] =>
164126
terms.filter(s => !lowerStopWords.contains(toLower(s)))
165127
}
@@ -185,4 +147,11 @@ object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
185147

186148
@Since("1.6.0")
187149
override def load(path: String): StopWordsRemover = super.load(path)
150+
151+
/**
 * Stop words for the language
 * Supported languages: Danish, Dutch, English, Finnish, French, German, Hungarian,
 * Italian, Norwegian, Portuguese, Russian, Spanish, Swedish, Turkish
 *
 * @param language language name; lowercased before lookup, so matching ignores case
 * @return the stop words returned by `StopWords.readStopWords` for that language
 * @throws IllegalArgumentException if the language is not in the supported list
 */
def loadStopWords(language: String): Array[String] = {
  // Locale.ROOT keeps the lowercasing independent of the JVM default locale
  // (a Turkish locale maps "I" to a dotless "ı", which would never match).
  StopWords.readStopWords(language.toLowerCase(java.util.Locale.ROOT))
}
188157
}

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,24 @@ class StopWordsRemoverSuite
5454
testStopWordsRemover(remover, dataSet)
5555
}
5656

57+
test("StopWordsRemover with particular stop words list") {
  // A non-empty stopWords param is used as-is instead of the list read for `language`.
  val customRemover = new StopWordsRemover()
    .setInputCol("raw")
    .setOutputCol("filtered")
    .setStopWords(Array("test", "a", "an", "the"))
  // Each row pairs the raw tokens with the expected column handed to testStopWordsRemover.
  // The upper-case row relies on caseSensitive defaulting to false.
  val rawAndExpected = sqlContext.createDataFrame(Seq(
    (Seq("test", "test"), Seq()),
    (Seq("a", "b", "c", "d"), Seq("b", "c")),
    (Seq("a", "the", "an"), Seq()),
    (Seq("A", "The", "AN"), Seq()),
    (Seq(null), Seq(null)),
    (Seq(), Seq())
  )).toDF("raw", "expected")

  testStopWordsRemover(customRemover, rawAndExpected)
}
74+
5775
test("StopWordsRemover case sensitive") {
5876
val remover = new StopWordsRemover()
5977
.setInputCol("raw")
@@ -68,11 +86,11 @@ class StopWordsRemoverSuite
6886
}
6987

7088
test("StopWordsRemover with ignored words") {
71-
val ignoredWords = Array("a")
89+
val stopWords = StopWordsRemover.loadStopWords("english").toSet -- Set("a")
7290
val remover = new StopWordsRemover()
7391
.setInputCol("raw")
7492
.setOutputCol("filtered")
75-
.setIgnoredWords(ignoredWords)
93+
.setStopWords(stopWords.toArray)
7694
val dataSet = sqlContext.createDataFrame(Seq(
7795
(Seq("python", "scala", "a"), Seq("python", "scala", "a")),
7896
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
@@ -82,11 +100,11 @@ class StopWordsRemoverSuite
82100
}
83101

84102
test("StopWordsRemover with additional words") {
85-
val additionalWords = Array("python", "scala")
103+
val stopWords = StopWordsRemover.loadStopWords("english").toSet ++ Set("python", "scala")
86104
val remover = new StopWordsRemover()
87105
.setInputCol("raw")
88106
.setOutputCol("filtered")
89-
.setAdditionalWords(additionalWords)
107+
.setStopWords(stopWords.toArray)
90108
val dataSet = sqlContext.createDataFrame(Seq(
91109
(Seq("python", "scala", "a"), Seq()),
92110
(Seq("Python", "Scala", "swift"), Seq("swift"))

0 commit comments

Comments
 (0)