Commit 01471ec

address feedback
1 parent cb786ee · commit 01471ec

File tree

4 files changed (+32, -24 lines)


mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 6 additions & 5 deletions
@@ -49,7 +49,7 @@ class StopWordsRemover(override val uid: String)
   /**
    * The words to be filtered out.
    * Default: English stop words
-   * @see [[StopWordsRemover.loadStopWords()]]
+   * @see [[StopWordsRemover.loadDefaultStopWords()]]
    * @group param
    */
   val stopWords: StringArrayParam =
@@ -89,7 +89,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getLocale: String = $(locale)
 
-  setDefault(stopWords -> StopWordsRemover.loadStopWords("english"),
+  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
     caseSensitive -> false, locale -> "en")
 
   @Since("2.0.0")
@@ -123,20 +123,21 @@ object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
 
   private def loadLocale(value : String) = new Locale(value)
 
-  private val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
+  private[feature]
+  val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
     "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")
 
   @Since("1.6.0")
   override def load(path: String): StopWordsRemover = super.load(path)
 
   /**
-   * Load stop words for the language
+   * Loads the default stop words for the given language.
    * Supported languages: danish, dutch, english, finnish, french, german, hungarian,
    * italian, norwegian, portuguese, russian, spanish, swedish, turkish
    * @see [[http://anoncvs.postgresql.org/cvsweb.cgi/pgsql/src/backend/snowball/stopwords/]]
    */
   @Since("2.0.0")
-  def loadStopWords(language: String): Array[String] = {
+  def loadDefaultStopWords(language: String): Array[String] = {
     require(supportedLanguages.contains(language),
       s"$language is not in the supported language list: ${supportedLanguages.mkString(", ")}.")
     val is = getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
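
For reference, a minimal sketch (not part of this commit) of how the renamed Scala API reads from caller code after this change, assuming Spark 2.0+, an existing SparkSession named spark, and made-up column names and sample data:

import org.apache.spark.ml.feature.StopWordsRemover

// Load the bundled English defaults via the renamed helper (formerly loadStopWords).
val english = StopWordsRemover.loadDefaultStopWords("english")

// Extend the defaults with a couple of extra words and configure the transformer.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(english ++ Array("python", "scala"))

// Hypothetical input: a single row holding a tokenized sentence.
val df = spark.createDataFrame(Seq(
  (0, Seq("i", "wrote", "some", "scala", "code"))
)).toDF("id", "raw")

remover.transform(df).select("filtered").show(truncate = false)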

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 10 additions & 3 deletions
@@ -85,8 +85,15 @@ class StopWordsRemoverSuite
     testStopWordsRemover(remover, dataSet)
   }
 
+  test("default stop words of supported languages are not empty") {
+    StopWordsRemover.supportedLanguages.foreach { lang =>
+      assert(StopWordsRemover.loadDefaultStopWords(lang).nonEmpty,
+        s"The default stop words of $lang cannot be empty.")
+    }
+  }
+
   test("StopWordsRemover with language selection") {
-    val stopWords = StopWordsRemover.loadStopWords("turkish")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("turkish")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
@@ -100,7 +107,7 @@ class StopWordsRemoverSuite
   }
 
   test("StopWordsRemover with ignored words") {
-    val stopWords = StopWordsRemover.loadStopWords("english").toSet -- Set("a")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("english").toSet -- Set("a")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
@@ -114,7 +121,7 @@ class StopWordsRemoverSuite
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWordsRemover.loadStopWords("english").toSet ++ Set("python", "scala")
+    val stopWords = StopWordsRemover.loadDefaultStopWords("english").toSet ++ Set("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
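
The language-selection test maps directly onto end-user code; a rough sketch along the same lines (again assuming a SparkSession named spark; the sample tokens are the ones used in the pyspark/ml/tests.py change below, and which of them get filtered depends on the bundled Turkish list):

import org.apache.spark.ml.feature.StopWordsRemover

// Swap the English defaults for the bundled Turkish stop words.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(StopWordsRemover.loadDefaultStopWords("turkish"))

val df = spark.createDataFrame(Seq(
  (0, Seq("acaba", "ama", "biri", "spark"))
)).toDF("id", "raw")

remover.transform(df).select("filtered").show(truncate = false)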

python/pyspark/ml/feature.py

Lines changed: 14 additions & 14 deletions
@@ -1743,13 +1743,13 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
     def __init__(self, inputCol=None, outputCol=None, stopWords=None,
                  caseSensitive=False, locale="en"):
         """
-        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
-                 caseSensitive=false, locale="en")
+        __init__(self, inputCol=None, outputCol=None, stopWords=None, \
+                 caseSensitive=false, locale="en")
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
-        self._setDefault(stopWords=StopWordsRemover.loadStopWords("english"),
+        self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
                          caseSensitive=False, locale="en")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -1759,8 +1759,8 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
     def setParams(self, inputCol=None, outputCol=None, stopWords=None,
                   caseSensitive=False, locale="en"):
         """
-        setParams(self, inputCol="input", outputCol="output", stopWords=None,
-                  caseSensitive=false, locale="en")
+        setParams(self, inputCol="input", outputCol="output", stopWords=None, \
+                  caseSensitive=false, locale="en")
         Sets params for this StopWordRemover.
         """
         kwargs = self.setParams._input_kwargs
@@ -1769,56 +1769,56 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None,
     @since("1.6.0")
     def setStopWords(self, value):
         """
-        Specify the stopwords to be filtered.
+        Sets the value of :py:attr:`stopWords`.
         """
         return self._set(stopWords=value)
 
     @since("1.6.0")
     def getStopWords(self):
         """
-        Get the stopwords.
+        Gets the value of :py:attr:`stopWords` or its default value.
         """
         return self.getOrDefault(self.stopWords)
 
     @since("1.6.0")
     def setCaseSensitive(self, value):
         """
-        Set whether to do a case sensitive comparison over the stop words
+        Sets the value of :py:attr:`caseSensitive`.
         """
         return self._set(caseSensitive=value)
 
     @since("1.6.0")
     def getCaseSensitive(self):
         """
-        Get whether to do a case sensitive comparison over the stop words.
+        Gets the value of :py:attr:`caseSensitive` or its default value.
         """
         return self.getOrDefault(self.caseSensitive)
 
     @since("2.0.0")
     def setLocale(self, value):
         """
-        Set locale for doing a case sensitive comparison
+        Sets the value of :py:attr:`locale`.
         """
         self._set(caseSensitive=value)
         return self
 
     @since("2.0.0")
     def getLocale(self):
         """
-        Get locale for doing a case sensitive comparison
+        Gets the value of :py:attr:`locale`.
         """
         return self.getOrDefault(self.caseSensitive)
 
     @staticmethod
     @since("2.0.0")
-    def loadStopWords(language):
+    def loadDefaultStopWords(language):
         """
-        Load stop words for the language
+        Loads the default stop words for the given language.
         Supported languages: danish, dutch, english, finnish, french, german, hungarian,
         italian, norwegian, portuguese, russian, spanish, swedish, turkish
         """
         stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWordsRemover
-        return list(stopWordsObj.loadStopWords(language))
+        return list(stopWordsObj.loadDefaultStopWords(language))
 
 
 @inherit_doc

python/pyspark/ml/tests.py

Lines changed: 2 additions & 2 deletions
@@ -410,15 +410,15 @@ def test_stopwordsremover(self):
         self.assertEqual(transformedDF.head().output, ["panda"])
         self.assertEqual(type(stopWordRemover.getStopWords()), list)
         self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
-        # with particular stop words list
+        # Custom
         stopwords = ["panda"]
         stopWordRemover.setStopWords(stopwords)
         self.assertEqual(stopWordRemover.getInputCol(), "input")
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["a"])
         # with language selection
-        stopwords = StopWordsRemover.loadStopWords("turkish")
+        stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
         dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])])
         stopWordRemover.setStopWords(stopwords)
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
