From 78c0362891c42c0d92e4275ce2a3285c2a67e153 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 11 Aug 2015 00:27:14 -0700
Subject: [PATCH 01/10] Start work towards adding StopWordsRemover Python
 interface.

---
 python/pyspark/ml/feature.py | 40 +++++++++++++++++++++++++++++++++++-
 python/pyspark/ml/tests.py   | 10 +++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index cb4dfa21298ce..d2ef957debc3b 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -29,7 +29,8 @@
 __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
            'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
            'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
-           'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+           'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel',
+           'StopWordsRemover']
 
 
 @inherit_doc
@@ -761,6 +762,43 @@ class StringIndexerModel(JavaModel):
     """
 
 
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    A feature transformer that filters out stop words from input.
+    Note: null values in the input array are preserved unless null is explicitly added to stopWords.
+    """
+    def __init__(self, inputCol=None, outputCol=None, stopWords=[]):
+        """
+        Initialize this instance of the StopWordsRemover.
+        """
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover")
+        self.uid = self._java_obj.uid
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, stopWords=[]):
+        """
+        setParams(self, inputCol="input", outputCol="output")
+        Sets params for this StopWordsRemover.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStopWords(self, value):
+        """
+        Specify the stopwords to be filtered.
+        """
+        return self.setStopWords(value)
+
+    def getStopWords(self):
+        """
+        Get the stopwords.
+        """
+        return self._java_obj.getStopWords()
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c151d21fd661a..daca592c6c230 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -263,6 +263,16 @@ def test_ngram(self):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([
+            ([["a", "b", "c", "d", "e"]])], ["input"])
+        stopwordremover = StopWordsRemover()
+        stopwords = ["a", "b", "c", "d"]
+        stopwordremover.setStopWords(stopwords)
+        self.assertEquals(stopwordremover.getStopWords(), stopwords)
+        transformedDF = stopwordremover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["e", "e"])
 
 if __name__ == "__main__":
     unittest.main()

From 7a65dc3d647721dc4a334e8aa6da8b338dbd746d Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 11 Aug 2015 16:07:12 -0700
Subject: [PATCH 02/10] pep8 fix (add second blank line between classes)

---
 python/pyspark/ml/feature.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d2ef957debc3b..87d51f32f647b 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -799,6 +799,7 @@ def getStopWords(self):
         """
         return self._java_obj.getStopWords()
 
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):

From c634fa14a0997ee7746ad670b6ff9dbbea9db2bc Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 11 Aug 2015 17:22:03 -0700
Subject: [PATCH 03/10] Make the stopwords param settable from Python & Java,
 fix test

---
 .../spark/ml/feature/StopWordsRemover.scala |  4 ++--
 python/pyspark/ml/feature.py                | 15 +++++++++++----
 python/pyspark/ml/tests.py                  |  8 +++++---
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 3cc41424460f2..f08104ff6482a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param}
+import org.apache.spark.ml.param.{ParamMap, StringArrayParam, BooleanParam, Param}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType}
@@ -100,7 +100,7 @@ class StopWordsRemover(override val uid: String)
    * the stop words set to be filtered out
    * @group param
    */
-  val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words")
+  val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words")
 
   /** @group setParam */
   def setStopWords(value: Array[String]): this.type = set(stopWords, value)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 87d51f32f647b..7adcbd941c0d9 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -769,12 +769,18 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     A feature transformer that filters out stop words from input.
     Note: null values in the input array are preserved unless null is explicitly added to stopWords.
     """
+    # a placeholder to make the stopwords show up in generated doc
+    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+
+    @keyword_only
     def __init__(self, inputCol=None, outputCol=None, stopWords=[]):
         """
         Initialize this instance of the StopWordsRemover.
         """
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover")
-        self.uid = self._java_obj.uid
+        super(StopWordsRemover, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+                                            self.uid)
+        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -791,13 +797,14 @@ def setStopWords(self, value):
         """
         Specify the stopwords to be filtered.
         """
-        return self.setStopWords(value)
+        self._paramMap[self.stopWords] = value
+        return self
 
     def getStopWords(self):
         """
         Get the stopwords.
         """
-        return self._java_obj.getStopWords()
+        return self.getOrDefault(self.stopWords)
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index daca592c6c230..d7aa96ef80367 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -265,14 +265,16 @@ def test_ngram(self):
 
     def test_stopwordsremover(self):
         sqlContext = SQLContext(self.sc)
+        data = ["a", "b", "c", "d", "e"]
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
-        stopwordremover = StopWordsRemover()
+            ([data])], ["input"])
+        stopwordremover = StopWordsRemover(inputCol="input", outputCol="output")
         stopwords = ["a", "b", "c", "d"]
         stopwordremover.setStopWords(stopwords)
+        self.assertEquals(stopwordremover.getInputCol(), "input")
         self.assertEquals(stopwordremover.getStopWords(), stopwords)
         transformedDF = stopwordremover.transform(dataset)
-        self.assertEquals(transformedDF.head().output, ["e", "e"])
+        self.assertEquals(transformedDF.head().output, ["e"])
 
 if __name__ == "__main__":
     unittest.main()
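With this patch the stop words live in the Python-side param map rather than being read back
from the JVM, so a set/get round trip stays entirely in Python. A minimal sketch of the usage
at this point in the series, assuming a running SparkContext (the column names are illustrative):

    from pyspark.ml.feature import StopWordsRemover

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    remover.setStopWords(["a", "the"])  # stored in the Python-side _paramMap
    remover.getStopWords()              # read back via getOrDefault -> ["a", "the"]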
From 12fb73e55438d7e9e1961ce46954684624362482 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 26 Aug 2015 13:49:49 -0700
Subject: [PATCH 04/10] Some progress on code review, add caseSensitive option
 for stopwords, use English stop words as default

---
 .../spark/ml/feature/StopWordsRemover.scala |  6 ++--
 python/pyspark/ml/feature.py                | 31 ++++++++++++++++++---
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d77ea08db657..918fee1fd0316 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp
 /**
  * stop words list
  */
-private object StopWords {
+protected[spark] object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
   * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val ENGLISH_STOP_WORDS = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am",
     "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone",
     "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.ENGLISH_STOP_WORDS, caseSensitive -> false)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4fc649723b2d6..8371545f4e30e 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -818,7 +818,6 @@ class StringIndexerModel(JavaModel):
     Model fitted by StringIndexer.
     """
 
-
 class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     """
     .. note:: Experimental
@@ -828,23 +827,32 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     """
     # a placeholder to make the stopwords show up in generated doc
     stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+                          "comparison over the stop words")
+
+    ENGLISH_STOP_WORDS = SparkContext._jvm.org.apache.spark.ml.feature.StopWords$.ENGLISH_STOP_WORDS
 
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, stopWords=[]):
+    def __init__(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
+                 caseSensitive=false):
         """
-        Initialize this instance of the StopWordsRemover.
+        __init__(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
+                 caseSensitive=false)
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
         self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+        self.caseSensitive = Param(self._dummy(), "caseSensitive", "whether to do a case " +
+                                   "sensitive comparison over the stop words")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
-    def setParams(self, inputCol=None, outputCol=None, stopWords=[]):
+    def setParams(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
+                  caseSensitive=false):
         """
-        setParams(self, inputCol="input", outputCol="output")
+        setParams(self, inputCol="input", outputCol="output", caseSensitive=false)
         Sets params for this StopWordsRemover.
         """
         kwargs = self.setParams._input_kwargs
@@ -863,6 +871,19 @@ def getStopWords(self):
         """
         return self.getOrDefault(self.stopWords)
 
+    def setCaseSensitive(self, value):
+        """
+        Set whether to do a case sensitive comparison over the stop words
+        """
+        self._paramMap[self.caseSensitive] = value
+        return self
+
+    def getCaseSensitive(self):
+        """
+        Get whether to do a case sensitive comparison over the stop words.
+        """
+        return self.getOrDefault(self.caseSensitive)
+
 
 @inherit_doc
 @ignore_unicode_prefix
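The caseSensitive param introduced here is meant to behave as follows. This is a hedged sketch of
the target behavior rather than of this exact revision, since the `false` defaults in this diff
are still the Scala literal and are corrected to Python's False in the next patch:

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered",
                               stopWords=["The"], caseSensitive=True)
    # caseSensitive=True removes only the exact token "The";
    # with the default of False, "the" and "THE" would be filtered as well.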
From 84bc507492829e7174288fe7399c892bf1391a3e Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 26 Aug 2015 15:32:46 -0700
Subject: [PATCH 05/10] Use the English stop words from Scala by default

---
 python/pyspark/ml/feature.py | 22 +++++++++++++---------
 python/pyspark/ml/tests.py   | 11 ++++++++---
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 8371545f4e30e..5412f39b21dc8 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
 from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
@@ -829,14 +829,12 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
     caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
                           "comparison over the stop words")
-
-    ENGLISH_STOP_WORDS = SparkContext._jvm.org.apache.spark.ml.feature.StopWords$.ENGLISH_STOP_WORDS
 
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
-                 caseSensitive=false):
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False):
         """
-        __init__(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,
                  caseSensitive=false)
         """
         super(StopWordsRemover, self).__init__()
@@ -845,14 +843,20 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
         self.stopWords = Param(self, "stopWords", "The words to be filtered out")
         self.caseSensitive = Param(self._dummy(), "caseSensitive", "whether to do a case " +
                                    "sensitive comparison over the stop words")
+        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+        defaultStopWords = stopWordsObj.ENGLISH_STOP_WORDS()
+        print "Constructing java param pair for value "+str(defaultStopWords)
+        print "Input class is "+defaultStopWords.__class__.__name__
+        self._setDefault(stopWords=defaultStopWords)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
-    def setParams(self, inputCol=None, outputCol=None, stopWords=ENGLISH_STOP_WORDS,
-                  caseSensitive=false):
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False):
         """
-        setParams(self, inputCol="input", outputCol="output", caseSensitive=false)
+        setParams(self, inputCol="input", outputCol="output", stopWords=None,
+                  caseSensitive=false)
         Sets params for this StopWordsRemover.
         """
         kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index d7aa96ef80367..01f8b6d597a43 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -265,16 +265,21 @@ def test_ngram(self):
 
     def test_stopwordsremover(self):
         sqlContext = SQLContext(self.sc)
-        data = ["a", "b", "c", "d", "e"]
+        data = ["a", "panda"]
         dataset = sqlContext.createDataFrame([
             ([data])], ["input"])
         stopwordremover = StopWordsRemover(inputCol="input", outputCol="output")
-        stopwords = ["a", "b", "c", "d"]
+        # Default
+        self.assertEquals(stopwordremover.getInputCol(), "input")
+        transformedDF = stopwordremover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
         stopwordremover.setStopWords(stopwords)
         self.assertEquals(stopwordremover.getInputCol(), "input")
         self.assertEquals(stopwordremover.getStopWords(), stopwords)
         transformedDF = stopwordremover.transform(dataset)
-        self.assertEquals(transformedDF.head().output, ["e"])
+        self.assertEquals(transformedDF.head().output, ["a"])
 
 if __name__ == "__main__":
     unittest.main()
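Because the default list is now copied from the Scala StopWords object at construction time, a
freshly built remover already reports the English stop words. A small sketch, assuming an active
SparkContext so that _jvm() can reach the JVM:

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    english = remover.getStopWords()   # populated from StopWords.ENGLISH_STOP_WORDS
    "about" in english                 # True
    "panda" in english                 # False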
From acfc9fe2bd2dc9903bddfe932b12123861e0aef6 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 26 Aug 2015 21:46:08 -0700
Subject: [PATCH 06/10] CR feedback (whitespace and remove prints)

---
 python/pyspark/ml/feature.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 5412f39b21dc8..948a6d9804db6 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -818,6 +818,7 @@ class StringIndexerModel(JavaModel):
     Model fitted by StringIndexer.
     """
 
+
 class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     """
     .. note:: Experimental
@@ -829,7 +830,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
     caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
                           "comparison over the stop words")
-
+
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, stopWords=None,
                  caseSensitive=False):
@@ -845,8 +846,6 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
                                    "sensitive comparison over the stop words")
         stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
         defaultStopWords = stopWordsObj.ENGLISH_STOP_WORDS()
-        print "Constructing java param pair for value "+str(defaultStopWords)
-        print "Input class is "+defaultStopWords.__class__.__name__
         self._setDefault(stopWords=defaultStopWords)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
""" kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index d7aa96ef80367..01f8b6d597a43 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -265,16 +265,21 @@ def test_ngram(self): def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) - data = ["a", "b", "c", "d", "e"] + data = ["a", "panda"] dataset = sqlContext.createDataFrame([ ([data])], ["input"]) stopwordremover = StopWordsRemover(inputCol="input", outputCol="output") - stopwords = ["a", "b", "c", "d"] + # Default + self.assertEquals(stopwordremover.getInputCol(), "input") + transformedDF = stopwordremover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["panda"]) + # Custom + stopwords = ["panda"] stopwordremover.setStopWords(stopwords) self.assertEquals(stopwordremover.getInputCol(), "input") self.assertEquals(stopwordremover.getStopWords(), stopwords) transformedDF = stopwordremover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["e"]) + self.assertEquals(transformedDF.head().output, ["a"]) if __name__ == "__main__": unittest.main() From acfc9fe2bd2dc9903bddfe932b12123861e0aef6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 26 Aug 2015 21:46:08 -0700 Subject: [PATCH 06/10] CR feedback (whitespace and remove prints) --- python/pyspark/ml/feature.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5412f39b21dc8..948a6d9804db6 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -818,6 +818,7 @@ class StringIndexerModel(JavaModel): Model fitted by StringIndexer. """ + class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): """ .. 
From 7767df04efd0a996d563e6ce19071cc365b27205 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Fri, 28 Aug 2015 14:46:08 -0700
Subject: [PATCH 08/10] CR feedback from mengxr

---
 .../apache/spark/ml/feature/StopWordsRemover.scala |  6 +++---
 .../spark/ml/feature/StopWordsRemoverSuite.scala   |  2 +-
 python/pyspark/ml/feature.py                       |  8 ++++----
 python/pyspark/ml/tests.py                         | 14 +++++++-------
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 918fee1fd0316..7da430c7d16df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp
 /**
  * stop words list
 */
-protected[spark] object StopWords {
+private[spark] object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
   * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val ENGLISH_STOP_WORDS = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am",
     "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone",
     "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWords.ENGLISH_STOP_WORDS, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.English, caseSensitive -> false)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index d962525adf058..e0d433f566c25 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.ENGLISH_STOP_WORDS ++ Array("python", "scala")
+    val stopWords = StopWords.English ++ Array("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 948a6d9804db6..8a514283c9e84 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -835,17 +835,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
     def __init__(self, inputCol=None, outputCol=None, stopWords=None,
                  caseSensitive=False):
         """
-        __init__(self, inputCol=None, outputCol=None, stopWords=None,
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
                  caseSensitive=false)
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
         self.stopWords = Param(self, "stopWords", "The words to be filtered out")
-        self.caseSensitive = Param(self._dummy(), "caseSensitive", "whether to do a case " +
+        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
                                    "sensitive comparison over the stop words")
         stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
-        defaultStopWords = stopWordsObj.ENGLISH_STOP_WORDS()
+        defaultStopWords = stopWordsObj.English()
         self._setDefault(stopWords=defaultStopWords)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -854,7 +854,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
     def setParams(self, inputCol=None, outputCol=None, stopWords=None,
                   caseSensitive=False):
         """
-        setParams(self, inputCol="input", outputCol="output", stopWords=None,
+        setParams(self, inputCol="input", outputCol="output", stopWords=None,\
                   caseSensitive=false)
         Sets params for this StopWordsRemover.
        """
         kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 01f8b6d597a43..eed043f8670cc 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -268,17 +268,17 @@ def test_stopwordsremover(self):
         data = ["a", "panda"]
         dataset = sqlContext.createDataFrame([
             ([data])], ["input"])
-        stopwordremover = StopWordsRemover(inputCol="input", outputCol="output")
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
         # Default
-        self.assertEquals(stopwordremover.getInputCol(), "input")
-        transformedDF = stopwordremover.transform(dataset)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["panda"])
         # Custom
         stopwords = ["panda"]
-        stopwordremover.setStopWords(stopwords)
-        self.assertEquals(stopwordremover.getInputCol(), "input")
-        self.assertEquals(stopwordremover.getStopWords(), stopwords)
-        transformedDF = stopwordremover.transform(dataset)
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a"])

From 345bde2ecc99b0c577eedacf273a2570c727609d Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Fri, 28 Aug 2015 15:00:32 -0700
Subject: [PATCH 09/10] Add an extra blank line that ended up missing after
 the automatic merge

---
 python/pyspark/ml/tests.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4648c170d1836..1c1891ac7e5bb 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -284,6 +284,7 @@ def test_stopwordsremover(self):
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a"])
 
+
 class HasInducedError(Params):
 
     def __init__(self):
""" diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 01f8b6d597a43..eed043f8670cc 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -268,17 +268,17 @@ def test_stopwordsremover(self): data = ["a", "panda"] dataset = sqlContext.createDataFrame([ ([data])], ["input"]) - stopwordremover = StopWordsRemover(inputCol="input", outputCol="output") + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopwordremover.getInputCol(), "input") - transformedDF = stopwordremover.transform(dataset) + self.assertEquals(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) self.assertEquals(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] - stopwordremover.setStopWords(stopwords) - self.assertEquals(stopwordremover.getInputCol(), "input") - self.assertEquals(stopwordremover.getStopWords(), stopwords) - transformedDF = stopwordremover.transform(dataset) + stopWordRemover.setStopWords(stopwords) + self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEquals(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) self.assertEquals(transformedDF.head().output, ["a"]) if __name__ == "__main__": From 345bde2ecc99b0c577eedacf273a2570c727609d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Aug 2015 15:00:32 -0700 Subject: [PATCH 09/10] Add an extra blank line that ended up missing after the automatic merge --- python/pyspark/ml/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4648c170d1836..1c1891ac7e5bb 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -284,6 +284,7 @@ def test_stopwordsremover(self): transformedDF = stopWordRemover.transform(dataset) self.assertEquals(transformedDF.head().output, ["a"]) + class HasInducedError(Params): def __init__(self): From 62b821aa7c01b7098c8154e32103ca6d88e30206 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 31 Aug 2015 17:47:30 -0700 Subject: [PATCH 10/10] CR feedback --- python/pyspark/ml/tests.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1c1891ac7e5bb..b892318f50bd9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,7 +31,7 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext +from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params @@ -258,7 +258,7 @@ def test_idf(self): def test_ngram(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ - ([["a", "b", "c", "d", "e"]])], ["input"]) + Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) self.assertEqual(ngram0.getInputCol(), "input") @@ -268,9 +268,7 @@ def test_ngram(self): def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) - data = ["a", "panda"] - dataset = sqlContext.createDataFrame([ - ([data])], ["input"]) + dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default self.assertEquals(stopWordRemover.getInputCol(), "input")