
Commit e2d0aba: address feedback

1 parent 9f488fb

File tree

4 files changed (+27 / -19 lines)


mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 7 additions & 6 deletions
@@ -47,7 +47,7 @@ class StopWordsRemover(override val uid: String)
   /**
    * The words to be filtered out.
    * Default: English stop words
-   * @see [[StopWordsRemover.loadStopWords()]]
+   * @see [[StopWordsRemover.loadDefaultStopWords()]]
    * @group param
    */
   val stopWords: StringArrayParam =
@@ -65,15 +65,15 @@ class StopWordsRemover(override val uid: String)
    * @group param
    */
   val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
-    "whether to do a case-sensitive comparison over the stop stop words")
+    "whether to do a case-sensitive comparison over the stop words")
 
   /** @group setParam */
   def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)
 
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWordsRemover.loadStopWords("english"), caseSensitive -> false)
+  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
@@ -108,20 +108,21 @@ class StopWordsRemover(override val uid: String)
 @Since("1.6.0")
 object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
 
-  private val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
+  private[feature]
+  val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
     "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")
 
   @Since("1.6.0")
   override def load(path: String): StopWordsRemover = super.load(path)
 
   /**
-   * Load stop words for the language
+   * Loads the default stop words for the given language.
    * Supported languages: danish, dutch, english, finnish, french, german, hungarian,
    * italian, norwegian, portuguese, russian, spanish, swedish, turkish
    * @see [[http://anoncvs.postgresql.org/cvsweb.cgi/pgsql/src/backend/snowball/stopwords/]]
    */
   @Since("2.0.0")
-  def loadStopWords(language: String): Array[String] = {
+  def loadDefaultStopWords(language: String): Array[String] = {
     require(supportedLanguages.contains(language),
       s"$language is not in the supported language list: ${supportedLanguages.mkString(", ")}.")
     val is = getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
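
For reference, a minimal sketch of how a caller picks up the renamed method after this change; the DataFrame setup and the active SparkSession (`spark`) are illustrative assumptions, not part of this diff:

import org.apache.spark.ml.feature.StopWordsRemover

// Load the bundled defaults for a supported language and extend them with custom words.
val stopWords = StopWordsRemover.loadDefaultStopWords("english") ++ Array("python", "scala")

val remover = new StopWordsRemover()
  .setInputCol("raw")          // column holding Seq[String]
  .setOutputCol("filtered")
  .setStopWords(stopWords)
  .setCaseSensitive(false)

// Hypothetical one-row dataset; assumes a SparkSession in scope as `spark`.
val df = spark.createDataFrame(Seq(Tuple1(Seq("a", "quick", "scala", "example")))).toDF("raw")
remover.transform(df).show(truncate = false)   // filtered: [quick, example]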

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 10 additions & 3 deletions
@@ -85,8 +85,15 @@ class StopWordsRemoverSuite
     testStopWordsRemover(remover, dataSet)
   }
 
+  test("default stop words of supported languages are not empty") {
+    StopWordsRemover.supportedLanguages.foreach { lang =>
+      assert(StopWordsRemover.loadDefaultStopWords(lang).nonEmpty,
+        s"The default stop words of $lang cannot be empty.")
+    }
+  }
+
   test("StopWordsRemover with language selection") {
-    val stopWords = StopWordsRemover.loadStopWords("turkish")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("turkish")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
@@ -100,7 +107,7 @@ class StopWordsRemoverSuite
   }
 
   test("StopWordsRemover with ignored words") {
-    val stopWords = StopWordsRemover.loadStopWords("english").toSet -- Set("a")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("english").toSet -- Set("a")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
@@ -114,7 +121,7 @@ class StopWordsRemoverSuite
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWordsRemover.loadStopWords("english").toSet ++ Set("python", "scala")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("english").toSet ++ Set("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
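
A side note on the test added above: it compiles only because supportedLanguages was widened from private to private[feature] in StopWordsRemover.scala, which makes it visible to code under org.apache.spark.ml.feature (including this suite) while keeping it hidden from user code. A minimal illustration of that Scala visibility modifier, with hypothetical names:

package org.apache.spark.ml.feature

object VisibilityExample {
  // Readable from any class or test declared under org.apache.spark.ml.feature,
  // but not from code outside that package.
  private[feature] val supported = Set("english", "turkish")
}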

python/pyspark/ml/feature.py

Lines changed: 9 additions & 9 deletions
@@ -1778,7 +1778,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=
     @since("1.6.0")
     def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
         """
-        setParams(self, inputCol="input", outputCol="output", stopWords=None, caseSensitive=false)
+        setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
         Sets params for this StopWordRemover.
         """
         kwargs = self.setParams._input_kwargs
@@ -1787,43 +1787,43 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive
     @since("1.6.0")
     def setStopWords(self, value):
         """
-        Specify the stopwords to be filtered.
+        Sets the value of :py:attr:`stopWords`.
         """
         self._set(stopWords=value)
         return self
 
     @since("1.6.0")
     def getStopWords(self):
         """
-        Get the stopwords.
+        Gets the value of :py:attr:`stopWords` or its default value.
        """
         return self.getOrDefault(self.stopWords)
 
     @since("1.6.0")
     def setCaseSensitive(self, value):
         """
-        Set whether to do a case sensitive comparison over the stop words
+        Sets the value of :py:attr:`caseSensitive`.
         """
         self._set(caseSensitive=value)
         return self
 
     @since("1.6.0")
     def getCaseSensitive(self):
         """
-        Get whether to do a case sensitive comparison over the stop words.
+        Gets the value of :py:attr:`caseSensitive` or its default value.
         """
         return self.getOrDefault(self.caseSensitive)
 
     @staticmethod
     @since("2.0.0")
-    def loadStopWords(language):
+    def loadDefaultStopWords(language):
         """
-        Load stop words for the language
+        Loads the default stop words for the given language.
         Supported languages: danish, dutch, english, finnish, french, german, hungarian,
         italian, norwegian, portuguese, russian, spanish, swedish, turkish
         """
         stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWordsRemover
-        return list(stopWordsObj.loadStopWords(language))
+        return list(stopWordsObj.loadDefaultStopWords(language))
 
 
 @inherit_doc
@@ -1875,7 +1875,7 @@ def __init__(self, inputCol=None, outputCol=None):
     @since("1.3.0")
     def setParams(self, inputCol=None, outputCol=None):
         """
-        setParams(self, inputCol="input", outputCol="output")
+        setParams(self, inputCol=None, outputCol=None)
         Sets params for this Tokenizer.
         """
         kwargs = self.setParams._input_kwargs

python/pyspark/ml/tests.py

Lines changed: 1 addition & 1 deletion
@@ -418,7 +418,7 @@ def test_stopwordsremover(self):
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["a"])
         # with language selection
-        stopwords = StopWordsRemover.loadStopWords("turkish")
+        stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
         dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])])
         stopWordRemover.setStopWords(stopwords)
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)

0 commit comments