Skip to content

Commit b3aa957

Browse files
committed
add stopWordsRemover
1 parent 6e4fb0c commit b3aa957

File tree

2 files changed

+190
-0
lines changed

2 files changed

+190
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.ml.Transformer
22+
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
23+
import org.apache.spark.ml.param.{BooleanParam, Param}
24+
import org.apache.spark.ml.util.Identifiable
25+
import org.apache.spark.sql.DataFrame
26+
import org.apache.spark.sql.types.{StructField, ArrayType, StructType}
27+
import org.apache.spark.sql.functions.{col, udf}
28+
29+
/**
30+
* :: Experimental ::
31+
* stop words list
32+
*/
33+
@Experimental
object StopWords {
  /**
   * Default set of English stop words, all lower-case.
   * Used as the default value of [[StopWordsRemover.stopWords]].
   */
  val EnglishSet: Set[String] = Set(
    "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
    "has", "he", "in", "is", "it", "its", "of", "on", "that", "the",
    "to", "was", "were", "will", "with")
}
38+
39+
/**
40+
* :: Experimental ::
41+
* A feature transformer that filters out stop words from input
42+
* [[http://en.wikipedia.org/wiki/Stop_words]]
43+
*/
44+
@Experimental
class StopWordsRemover(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("stopWords"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * The stop words set to be filtered out.
   * Default: [[StopWords.EnglishSet]].
   * @group param
   */
  val stopWords: Param[Set[String]] = new Param(this, "stopWords", "stop words")

  /** @group setParam */
  def setStopWords(value: Set[String]): this.type = set(stopWords, value)

  /** @group getParam */
  def getStopWords: Set[String] = getOrDefault(stopWords)

  /**
   * Whether to do a case-sensitive comparison over the stop words.
   * Default: false (case-insensitive).
   * @group param
   */
  val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
    "whether to do case-sensitive filter")

  /** @group setParam */
  def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)

  /** @group getParam */
  def getCaseSensitive: Boolean = getOrDefault(caseSensitive)

  setDefault(stopWords -> StopWords.EnglishSet, caseSensitive -> false)

  /**
   * Filters stop words from the input array column and appends the result
   * as the output column. Null tokens are kept as-is (a null is never
   * treated as a stop word).
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)
    // Resolve the comparison strategy once, outside the udf body, so the
    // caseSensitive branch and the lower-casing of the stop-word set are
    // not re-evaluated for every row.
    val t = if ($(caseSensitive)) {
      val stopWordsSet = $(stopWords)
      udf { terms: Seq[String] =>
        terms.filterNot(s => s != null && stopWordsSet.contains(s))
      }
    } else {
      val lowerStopWords = $(stopWords).map(_.toLowerCase)
      udf { terms: Seq[String] =>
        terms.filterNot(s => s != null && lowerStopWords.contains(s.toLowerCase))
      }
    }
    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
  }

  /**
   * Validates that the input column is an ArrayType and appends the output
   * column (same element type and nullability) to the schema.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.isInstanceOf[ArrayType],
      s"The input column must be ArrayType, but got $inputType.")
    val outputFields = schema.fields :+
      StructField($(outputCol), inputType, schema($(inputCol)).nullable)
    StructType(outputFields)
  }
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.mllib.util.MLlibTestSparkContext
22+
import org.apache.spark.sql.{DataFrame, Row}
23+
24+
import scala.beans.BeanInfo
25+
26+
// Test fixture: `raw` is the input token array, `wanted` is the expected
// output after stop-word removal. @BeanInfo generates JavaBean accessors
// so createDataFrame can infer the schema from this case class.
@BeanInfo
case class StopWordsTestData(raw: Array[String], wanted: Array[String])
28+
29+
object StopWordsRemoverSuite extends SparkFunSuite {
  /**
   * Runs `t` on `dataset` and asserts, row by row, that the "filtered"
   * column matches the expected "wanted" column.
   */
  def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
    val transformed = t.transform(dataset)
    val rows = transformed.select("filtered", "wanted").collect()
    rows.foreach { case Row(actual, expected) =>
      assert(actual === expected)
    }
  }
}
39+
40+
class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.StopWordsRemoverSuite._

  test("StopWordsRemover default") {
    // Defaults: English stop-word set, case-insensitive; nulls pass through.
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
    val data = Seq(
      StopWordsTestData(Array("test", "test"), Array("test", "test")),
      StopWordsTestData(Array("a", "b", "c", "d"), Array("b", "c", "d")),
      StopWordsTestData(Array("a", "the", "an"), Array()),
      StopWordsTestData(Array("A", "The", "AN"), Array()),
      StopWordsTestData(Array(null), Array(null)),
      StopWordsTestData(Array(), Array()))
    testStopWordsRemover(remover, sqlContext.createDataFrame(data))
  }

  test("StopWordsRemover case sensitive") {
    // With caseSensitive=true, only exact-case matches are removed.
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setCaseSensitive(true)
    val data = Seq(
      StopWordsTestData(Array("A"), Array("A")),
      StopWordsTestData(Array("The", "the"), Array("The")))
    testStopWordsRemover(remover, sqlContext.createDataFrame(data))
  }

  test("StopWordsRemover with additional words") {
    // A user-supplied set replaces the default; comparison stays
    // case-insensitive by default.
    val stopWords = StopWords.EnglishSet + "python" + "scala"
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setStopWords(stopWords)
    val data = Seq(
      StopWordsTestData(Array("python", "scala", "a"), Array()),
      StopWordsTestData(Array("Python", "Scala", "swift"), Array("swift")))
    testStopWordsRemover(remover, sqlContext.createDataFrame(data))
  }
}

0 commit comments

Comments
 (0)