@@ -33,18 +33,14 @@ private[spark] object StopWords {
3333
3434 /** Read stop words list from resources */
3535 def readStopWords (language : String ): Array [String ] = {
36+ require(supportedLanguages.contains(language), s " $language is not in language list " )
3637 val is = getClass.getResourceAsStream(s " /org/apache/spark/ml/feature/stopwords/ $language.txt " )
3738 scala.io.Source .fromInputStream(is).getLines().toArray
3839 }
3940
4041 /** Supported languages list must be lowercase */
41- val supportedLanguages = Array (" danish" , " dutch" , " english" , " finnish" , " french" , " german" ,
42+ private val supportedLanguages = Set (" danish" , " dutch" , " english" , " finnish" , " french" , " german" ,
4243 " hungarian" , " italian" , " norwegian" , " portuguese" , " russian" , " spanish" , " swedish" , " turkish" )
43-
44- /** Languages and stopwords map */
45- val languageMap = supportedLanguages.map{
46- language => language -> readStopWords(language)
47- }.toMap
4844}
4945
5046/**
@@ -67,16 +63,13 @@ class StopWordsRemover(override val uid: String)
6763
6864 /**
6965 * the stop words set to be filtered out
70- * Default: [[StopWords.languageMap("english") ]]
66+ * Default: [[Array.empty ]]
7167 * @group param
7268 */
7369 val stopWords : StringArrayParam = new StringArrayParam (this , " stopWords" , " stop words" )
7470
7571 /** @group setParam */
76- def setStopWords (value : Array [String ]): this .type = {
77- set(stopWords, value)
78- set(language, " unknown" )
79- }
72+ def setStopWords (value : Array [String ]): this .type = set(stopWords, value)
8073
8174 /** @group getParam */
8275 def getStopWords : Array [String ] = $(stopWords)
@@ -96,70 +89,39 @@ class StopWordsRemover(override val uid: String)
9689 def getCaseSensitive : Boolean = $(caseSensitive)
9790
9891 /**
99- * the language of stop words
100- * Default: "english"
101- * @group param
102- */
92+ * the language of stop words
93+ * Supported languages: Danish, Dutch, English, Finnish, French, German, Hungarian,
94+ * Italian, Norwegian, Portuguese, Russian, Spanish, Swedish, Turkish
95+ * Default: "English"
96+ * @group param
97+ */
10398 val language : Param [String ] = new Param [String ](this , " language" , " stopwords language" )
10499
105100 /** @group setParam */
106- def setLanguage (value : String ): this .type = {
107- val lang = value.toLowerCase
108- require(StopWords .languageMap.contains(lang), s " $lang is not in language list " )
109- set(language, lang)
110- set(stopWords, StopWords .languageMap(lang))
111- }
101+ def setLanguage (value : String ): this .type = set(language, value.toLowerCase)
112102
113103 /** @group getParam */
114104 def getLanguage : String = $(language)
115105
116- /**
117- * the ignored stop words set to be ignored out
118- * Default: [[Array.empty ]]
119- * @group param
120- */
121- val ignoredWords : StringArrayParam = new StringArrayParam (this , " ignoredWords" ,
122- " the ignored stop words set to be ignored out" )
123-
124- /** @group setParam */
125- def setIgnoredWords (value : Array [String ]): this .type = set(ignoredWords, value)
126-
127- /** @group getParam */
128- def getIgnoredWords : Array [String ] = $(ignoredWords)
129-
130- /**
131- * the additional stop words set to be filtered out
132- * Default: [[Array.empty ]]
133- * @group param
134- */
135- val additionalWords : StringArrayParam = new StringArrayParam (this , " additionalWords" ,
136- " the additional stop words set to be filtered out" )
137-
138- /** @group setParam */
139- def setAdditionalWords (value : Array [String ]): this .type = set(additionalWords, value)
140-
141- /** @group getParam */
142- def getAdditionalWords : Array [String ] = $(additionalWords)
143-
144- setDefault(stopWords -> StopWords .languageMap(" english" ),
106+ setDefault(stopWords -> Array .empty[String ],
145107 language -> " english" ,
146- ignoredWords -> Array .empty[String ],
147- additionalWords -> Array .empty[String ],
148108 caseSensitive -> false )
149109
150110 override def transform (dataset : DataFrame ): DataFrame = {
111+ val stopWordsSet = if ($(stopWords).isEmpty) {
112+ StopWords .readStopWords($(language)).toSet
113+ } else {
114+ $(stopWords).toSet
115+ }
116+
151117 val outputSchema = transformSchema(dataset.schema)
152118 val t = if ($(caseSensitive)) {
153- val stopWordsSet = ($(stopWords) ++ $(additionalWords)).toSet -- $(ignoredWords).toSet
154119 udf { terms : Seq [String ] =>
155120 terms.filter(s => ! stopWordsSet.contains(s))
156121 }
157122 } else {
158123 val toLower = (s : String ) => if (s != null ) s.toLowerCase else s
159- val lowerStopWords = {
160- ($(stopWords) ++ $(additionalWords))
161- .map(toLower(_)).toSet -- $(ignoredWords).map(toLower(_)).toSet
162- }
124+ val lowerStopWords = stopWordsSet.map(toLower(_)).toSet
163125 udf { terms : Seq [String ] =>
164126 terms.filter(s => ! lowerStopWords.contains(toLower(s)))
165127 }
@@ -185,4 +147,11 @@ object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
185147
186148 @ Since (" 1.6.0" )
187149 override def load (path : String ): StopWordsRemover = super .load(path)
150+
151+ /**
152+ * Stop words for the language
153+ * Supported languages: Danish, Dutch, English, Finnish, French, German, Hungarian,
154+ * Italian, Norwegian, Portuguese, Russian, Spanish, Swedish, Turkish
155+ */
156+ def loadStopWords (language : String ): Array [String ] = StopWords .readStopWords(language.toLowerCase)
188157}
0 commit comments