From 7c61fb32801ed802b8792663b1769c9eddd1346e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 29 Jun 2015 19:42:40 +0800 Subject: [PATCH 1/7] add countVectorizer --- .../spark/ml/feature/CountVectorizer.scala | 80 ++++++++++++++++++ .../ml/feature/CountVectorizorSuite.scala | 83 +++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala new file mode 100644 index 000000000000..e84a887da079 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + */ +@Experimental +class CountVectorizer (override val uid: String, vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizer] { + + def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary) + + /** + * Corpus-specific stop words filter. Terms with count less than the given threshold are ignored. + * Default: 1 + * @group param + */ + val minTermCounts: IntParam = new IntParam(this, "minTermCounts", + "lower bound of effective term counts (>= 0)", ParamValidators.gtEq(1)) + + /** @group setParam */ + def setMinTermCounts(value: Int): this.type = set(minTermCounts, value) + + /** @group getParam */ + def getMinTermCounts: Int = $(minTermCounts) + + setDefault(minTermCounts -> 1) + + override protected def createTransformFunc: Seq[String] => Vector = { + val dict = vocabulary.zipWithIndex.toMap + document => + val termCounts = mutable.HashMap.empty[Int, Double] + document.foreach { term => + val index = dict.getOrElse(term, -1) + if (index >= 0) { + termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + } + } + Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be Array type but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): CountVectorizer = { + val copied = new CountVectorizer(uid, vocabulary) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala new file mode 100644 index 000000000000..7b11285617be --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizer(Array("empty"))) + } + + test("CountVectorizer common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, "a b c d".split(" ").toSeq), + (1, "a b b c d a".split(" ").toSeq), + (2, "a".split(" ").toSeq), + (3, "".split(" ").toSeq), // empty string + (3, "a notInDict d".split(" ").toSeq) // with words not in vocabulary + )).toDF("id", "words") + val hashingTF = new CountVectorizer(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermCounts(1) + val output = hashingTF.transform(df) + val features = output.select("features").collect() + + val expected = Seq( + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0))), + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0))), + Vectors.sparse(4, Seq((0, 1.0))), + Vectors.sparse(4, Seq()), + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) + + features.zip(expected).foreach(p => + assert(p._1.getAs[Vector](0) == p._2) + ) + } + + test("CountVectorizer with minTermCounts") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a b b c c c d ".split(" ").toSeq), + (1, "c c c c c c".split(" ").toSeq), + (2, "a".split(" ").toSeq), + (3, "e e e e e".split(" ").toSeq) + )).toDF("id", "words") + val cv = new CountVectorizer(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermCounts(3) + val output = cv.transform(df) + val features = output.select("features").collect() + + val expected = Seq( + Vectors.sparse(4, Seq((0, 3.0), (2, 3.0))), + Vectors.sparse(4, Seq((2, 6.0))), + Vectors.sparse(4, Seq()), + Vectors.sparse(4, Seq())) + + features.zip(expected).foreach(p => + assert(p._1.getAs[Vector](0) == p._2) + ) + } +} + + From 809fb5947728de16c9addd5a8e27a41371394ff9 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 29 Jun 2015 20:27:00 +0800 Subject: [PATCH 2/7] minor fix for ut --- .../apache/spark/ml/feature/CountVectorizorSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala index 7b11285617be..9bd5bdd25621 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -20,6 +20,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -35,11 +36,10 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { (3, "".split(" ").toSeq), // empty string (3, "a notInDict d".split(" ").toSeq) // with words not in vocabulary )).toDF("id", "words") - val hashingTF = new CountVectorizer(Array("a", "b", "c", "d")) + val cv = new CountVectorizer(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - .setMinTermCounts(1) - val output = hashingTF.transform(df) + val output = cv.transform(df) val features = output.select("features").collect() val expected = Seq( @@ -50,7 +50,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) features.zip(expected).foreach(p => - assert(p._1.getAs[Vector](0) == p._2) + assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14) ) } @@ -75,7 +75,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(4, Seq())) features.zip(expected).foreach(p => - assert(p._1.getAs[Vector](0) == p._2) + assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14) ) } } From 7ee1c310bdc6e4e9f26035c143a1daaba9a3212f Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 30 Jun 2015 10:50:38 +0800 Subject: [PATCH 3/7] extends HashingTF --- .../spark/ml/feature/CountVectorizer.scala | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index e84a887da079..2ddb6bea1277 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -14,17 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ /** * :: Experimental :: @@ -32,8 +31,7 @@ import org.apache.spark.sql.types.{StringType, ArrayType, DataType} * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. */ @Experimental -class CountVectorizer (override val uid: String, vocabulary: Array[String]) - extends UnaryTransformer[Seq[String], Vector, CountVectorizer] { +class CountVectorizer (override val uid: String, vocabulary: Array[String]) extends HashingTF{ def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary) @@ -43,7 +41,7 @@ class CountVectorizer (override val uid: String, vocabulary: Array[String]) * @group param */ val minTermCounts: IntParam = new IntParam(this, "minTermCounts", - "lower bound of effective term counts (>= 0)", ParamValidators.gtEq(1)) + "lower bound of effective term counts (>= 1)", ParamValidators.gtEq(1)) /** @group setParam */ def setMinTermCounts(value: Int): this.type = set(minTermCounts, value) @@ -51,28 +49,23 @@ class CountVectorizer (override val uid: String, vocabulary: Array[String]) /** @group getParam */ def getMinTermCounts: Int = $(minTermCounts) - setDefault(minTermCounts -> 1) + setDefault(minTermCounts -> 1, numFeatures -> vocabulary.size) - override protected def createTransformFunc: Seq[String] => Vector = { + override def transform(dataset: DataFrame): DataFrame = { val dict = vocabulary.zipWithIndex.toMap - document => + val t = udf { terms: Seq[String] => val termCounts = mutable.HashMap.empty[Int, Double] - document.foreach { term => + terms.foreach { term => val index = dict.getOrElse(term, -1) if (index >= 0) { termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) } } Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq) + } + dataset.withColumn($(outputCol), t(col($(inputCol)))) } - override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be Array type but got $inputType.") - } - - override protected def outputDataType: DataType = new VectorUDT() - override def copy(extra: ParamMap): CountVectorizer = { val copied = new CountVectorizer(uid, vocabulary) copyValues(copied, extra) From 99b0c14bba1c210c180598a8bbdb4e5364f17df6 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 2 Jul 2015 10:30:57 +0800 Subject: [PATCH 4/7] undo extension from HashingTF --- .../spark/ml/feature/CountVectorizer.scala | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 2ddb6bea1277..83c8eb8525e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} /** * :: Experimental :: @@ -31,7 +31,8 @@ import org.apache.spark.sql.functions._ * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. */ @Experimental -class CountVectorizer (override val uid: String, vocabulary: Array[String]) extends HashingTF{ +class CountVectorizer (override val uid: String, vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizer] { def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary) @@ -49,23 +50,28 @@ class CountVectorizer (override val uid: String, vocabulary: Array[String]) exte /** @group getParam */ def getMinTermCounts: Int = $(minTermCounts) - setDefault(minTermCounts -> 1, numFeatures -> vocabulary.size) + setDefault(minTermCounts -> 1) - override def transform(dataset: DataFrame): DataFrame = { + override protected def createTransformFunc: Seq[String] => Vector = { val dict = vocabulary.zipWithIndex.toMap - val t = udf { terms: Seq[String] => + document => val termCounts = mutable.HashMap.empty[Int, Double] - terms.foreach { term => - val index = dict.getOrElse(term, -1) - if (index >= 0) { - termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + document.foreach { term => + dict.get(term) match { + case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + case None => // ignore terms not in the vocabulary } } Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq) - } - dataset.withColumn($(outputCol), t(col($(inputCol)))) } + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be Array type but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT() + override def copy(extra: ParamMap): CountVectorizer = { val copied = new CountVectorizer(uid, vocabulary) copyValues(copied, extra) From 576728ab883e94a8312420696522fc3e627ab42f Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 6 Jul 2015 13:15:25 +0800 Subject: [PATCH 5/7] rename to model and some fix --- ...rizer.scala => CountVectorizerModel.scala} | 29 ++++---- .../ml/feature/CountVectorizorSuite.scala | 74 ++++++++----------- 2 files changed, 48 insertions(+), 55 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{CountVectorizer.scala => CountVectorizerModel.scala} (68%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala similarity index 68% rename from mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala index 83c8eb8525e0..f705ae359a91 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -31,26 +31,29 @@ import org.apache.spark.sql.types.{StringType, ArrayType, DataType} * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. */ @Experimental -class CountVectorizer (override val uid: String, vocabulary: Array[String]) - extends UnaryTransformer[Seq[String], Vector, CountVectorizer] { +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { - def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary) + def this(vocabulary: Array[String]) = + this(Identifiable.randomUID("countVectorizerModel"), vocabulary) /** - * Corpus-specific stop words filter. Terms with count less than the given threshold are ignored. + * Corpus-specific filter to neglect scarce words in a document. For each document, terms with + * frequency (count) less than the given threshold are ignored. * Default: 1 * @group param */ - val minTermCounts: IntParam = new IntParam(this, "minTermCounts", - "lower bound of effective term counts (>= 1)", ParamValidators.gtEq(1)) + val minTermFreq: IntParam = new IntParam(this, "minTermFreq", + "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + + "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) /** @group setParam */ - def setMinTermCounts(value: Int): this.type = set(minTermCounts, value) + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) /** @group getParam */ - def getMinTermCounts: Int = $(minTermCounts) + def getMinTermFreq: Int = $(minTermFreq) - setDefault(minTermCounts -> 1) + setDefault(minTermFreq -> 1) override protected def createTransformFunc: Seq[String] => Vector = { val dict = vocabulary.zipWithIndex.toMap @@ -62,18 +65,18 @@ class CountVectorizer (override val uid: String, vocabulary: Array[String]) case None => // ignore terms not in the vocabulary } } - Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq) + Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) } override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be Array type but got $inputType.") + s"Input type must be ArrayType(StringType) but got $inputType.") } override protected def outputDataType: DataType = new VectorUDT() - override def copy(extra: ParamMap): CountVectorizer = { - val copied = new CountVectorizer(uid, vocabulary) + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary) copyValues(copied, extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala index 9bd5bdd25621..abee1e2b5c76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -25,58 +25,48 @@ import org.apache.spark.mllib.util.TestingUtils._ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { - ParamsSuite.checkParams(new CountVectorizer(Array("empty"))) + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } - test("CountVectorizer common cases") { + test("CountVectorizerModel common cases") { val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split(" ").toSeq), - (1, "a b b c d a".split(" ").toSeq), - (2, "a".split(" ").toSeq), - (3, "".split(" ").toSeq), // empty string - (3, "a notInDict d".split(" ").toSeq) // with words not in vocabulary - )).toDF("id", "words") - val cv = new CountVectorizer(Array("a", "b", "c", "d")) + (0, "a b c d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, "a b b c d a".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), + (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string + (4, "a notInDict d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - val output = cv.transform(df) - val features = output.select("features").collect() - - val expected = Seq( - Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0))), - Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0))), - Vectors.sparse(4, Seq((0, 1.0))), - Vectors.sparse(4, Seq()), - Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) - - features.zip(expected).foreach(p => - assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14) - ) + val output = cv.transform(df).collect() + output.foreach{ p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } } - test("CountVectorizer with minTermCounts") { + test("CountVectorizerModel with minTermFreq") { val df = sqlContext.createDataFrame(Seq( - (0, "a a a b b c c c d ".split(" ").toSeq), - (1, "c c c c c c".split(" ").toSeq), - (2, "a".split(" ").toSeq), - (3, "e e e e e".split(" ").toSeq) - )).toDF("id", "words") - val cv = new CountVectorizer(Array("a", "b", "c", "d")) + (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), + (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - .setMinTermCounts(3) - val output = cv.transform(df) - val features = output.select("features").collect() - - val expected = Seq( - Vectors.sparse(4, Seq((0, 3.0), (2, 3.0))), - Vectors.sparse(4, Seq((2, 6.0))), - Vectors.sparse(4, Seq()), - Vectors.sparse(4, Seq())) - - features.zip(expected).foreach(p => - assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14) - ) + .setMinTermFreq(3) + val output = cv.transform(df).collect() + output.foreach{ p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } } } From 24728e4c35b422457c82ddc540c7de5ff8c47fd8 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 6 Jul 2015 13:21:06 +0800 Subject: [PATCH 6/7] style improvement --- .../org/apache/spark/ml/feature/CountVectorizorSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala index abee1e2b5c76..e90d9d4ef21f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -43,7 +43,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("words") .setOutputCol("features") val output = cv.transform(df).collect() - output.foreach{ p => + output.foreach { p => val features = p.getAs[Vector]("features") val expected = p.getAs[Vector]("expected") assert(features ~== expected absTol 1e-14) @@ -62,7 +62,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setOutputCol("features") .setMinTermFreq(3) val output = cv.transform(df).collect() - output.foreach{ p => + output.foreach { p => val features = p.getAs[Vector]("features") val expected = p.getAs[Vector]("expected") assert(features ~== expected absTol 1e-14) From 5f3f655c7bf2feaf9502e13abdf160657cc41fdc Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 9 Jul 2015 17:39:06 +0800 Subject: [PATCH 7/7] text change --- .../org/apache/spark/ml/feature/CountVectorizerModel.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala index f705ae359a91..6b77de89a033 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -35,10 +35,10 @@ class CountVectorizerModel (override val uid: String, val vocabulary: Array[Stri extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { def this(vocabulary: Array[String]) = - this(Identifiable.randomUID("countVectorizerModel"), vocabulary) + this(Identifiable.randomUID("cntVec"), vocabulary) /** - * Corpus-specific filter to neglect scarce words in a document. For each document, terms with + * Corpus-specific filter to ignore scarce words in a document. For each document, terms with * frequency (count) less than the given threshold are ignored. * Default: 1 * @group param