Skip to content

Commit d6ae7d4

Browse files
committed
[SPARK-14665][ML][PYTHON] Fixed bug with StopWordsRemover default stopwords
## What changes were proposed in this pull request? The default stopwords were a Java object. They are no longer. ## How was this patch tested? Unit test which failed before the fix Author: Joseph K. Bradley <[email protected]> Closes #12422 from jkbradley/pyspark-stopwords.
1 parent 83af297 commit d6ae7d4

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

python/pyspark/ml/feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
17651765
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
17661766
self.uid)
17671767
stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
1768-
defaultStopWords = stopWordsObj.English()
1768+
defaultStopWords = list(stopWordsObj.English())
17691769
self._setDefault(stopWords=defaultStopWords, caseSensitive=False)
17701770
kwargs = self.__init__._input_kwargs
17711771
self.setParams(**kwargs)

python/pyspark/ml/tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sys
2323
if sys.version > '3':
2424
xrange = range
25+
basestring = str
2526

2627
try:
2728
import xmlrunner
@@ -398,6 +399,8 @@ def test_stopwordsremover(self):
398399
self.assertEqual(stopWordRemover.getInputCol(), "input")
399400
transformedDF = stopWordRemover.transform(dataset)
400401
self.assertEqual(transformedDF.head().output, ["panda"])
402+
self.assertEqual(type(stopWordRemover.getStopWords()), list)
403+
self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
401404
# Custom
402405
stopwords = ["panda"]
403406
stopWordRemover.setStopWords(stopwords)

0 commit comments

Comments
 (0)