4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
 /**
  * stop words list
  */
-private object StopWords {
+private[spark] object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
    * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
     "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
     "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.English, caseSensitive -> false)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
2 changes: 1 addition & 1 deletion mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+    val stopWords = StopWords.English ++ Array("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
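Once the Python wrapper below is in place, the same extension pattern is available from PySpark. A minimal sketch, assuming a transformer constructed as in the new Python test; the list() call is an assumption that the py4j-backed default word array materializes cleanly:

    from pyspark.ml.feature import StopWordsRemover

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    # Mirror the Scala test above: extend the default English list with two
    # extra words. getStopWords() returns the default list loaded via py4j.
    remover.setStopWords(list(remover.getStopWords()) + ["python", "scala"])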
73 changes: 71 additions & 2 deletions python/pyspark/ml/feature.py
@@ -22,15 +22,15 @@
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
 from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
 
 __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
            'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
            'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
            'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
 
 
 @inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
     """
 
 
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):

Review comment (Contributor): style checker wants 2 blank lines

+    """
+    .. note:: Experimental
+
+    A feature transformer that filters out stop words from input.
+    Note: null values from the input array are preserved unless adding null to stopWords explicitly.
+    """
+    # a placeholder to make the stopwords show up in generated doc
+    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")

Review comment (Contributor): Can we also provide caseSensitive?
Reply (Contributor Author): Sure.
Reply (Contributor): Thanks!

+    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+                          "comparison over the stop words")
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False):
+        """
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+                 caseSensitive=False)
+        """
+        super(StopWordsRemover, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+                                            self.uid)
+        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
+                                   "sensitive comparison over the stop words")
+        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+        defaultStopWords = stopWordsObj.English()
+        self._setDefault(stopWords=defaultStopWords)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False):
+        """
+        setParams(self, inputCol=None, outputCol=None, stopWords=None,\
+                  caseSensitive=False)
+        Sets params for this StopWordsRemover.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStopWords(self, value):
+        """
+        Specify the stopwords to be filtered.
+        """
+        self._paramMap[self.stopWords] = value
+        return self
+
+    def getStopWords(self):
+        """
+        Get the stopwords.
+        """
+        return self.getOrDefault(self.stopWords)
+
+    def setCaseSensitive(self, value):
+        """
+        Set whether to do a case sensitive comparison over the stop words.
+        """
+        self._paramMap[self.caseSensitive] = value
+        return self
+
+    def getCaseSensitive(self):
+        """
+        Get whether to do a case sensitive comparison over the stop words.
+        """
+        return self.getOrDefault(self.caseSensitive)
+
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
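For context, a minimal end-to-end sketch of the new transformer; the data is hypothetical, and it assumes a live SparkContext named sc, as the test suite below does:

    from pyspark.sql import SQLContext, Row
    from pyspark.ml.feature import StopWordsRemover

    sqlContext = SQLContext(sc)
    # One row whose "raw" column holds an already-tokenized sentence.
    df = sqlContext.createDataFrame([Row(raw=["a", "quick", "look", "at", "pandas"])])

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    # "a" and "at" are on the default English list; the other tokens survive.
    print(remover.transform(df).head().filtered)  # ['quick', 'look', 'pandas']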
20 changes: 18 additions & 2 deletions python/pyspark/ml/tests.py
@@ -31,7 +31,7 @@
 import unittest
 
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
@@ -258,14 +258,30 @@ def test_idf(self):
     def test_ngram(self):
         sqlContext = SQLContext(self.sc)
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
+            Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
         self.assertEqual(ngram0.getInputCol(), "input")
         self.assertEqual(ngram0.getOutputCol(), "output")
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["a"])
+
 
 class HasInducedError(Params):
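As a postscript to the caseSensitive thread above, a sketch of what the flag changes; the tokens are hypothetical, and it assumes a SQLContext named sqlContext as in the tests:

    from pyspark.sql import Row
    from pyspark.ml.feature import StopWordsRemover

    df = sqlContext.createDataFrame([Row(raw=["The", "the", "swift"])])

    remover = StopWordsRemover(inputCol="raw", outputCol="filtered",
                               stopWords=["the"], caseSensitive=True)
    # Case-sensitive: only the exact token "the" matches the stop list.
    print(remover.transform(df).head().filtered)  # ['The', 'swift']

    remover.setCaseSensitive(False)
    # Case-insensitive (the default): both "The" and "the" are removed.
    print(remover.transform(df).head().filtered)  # ['swift']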