
Commit 3ea051a

add read/write to CountVectorizer
1 parent 7b62e6c commit 3ea051a
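Summary (from the diff below): this commit wires the ml persistence API into CountVectorizer. The estimator gets a stock DefaultParamsWriter/DefaultParamsReader pair, while CountVectorizerModel gets a custom writer/reader that also persists the fitted vocabulary. It additionally fixes a copy-pasted class name in IDFModelReader, drops trailing blank lines in MinMaxScaler, and adds read/write and params tests.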

File tree: 5 files changed, +97 -10 lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 66 additions & 6 deletions
@@ -16,17 +16,19 @@
  */
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
  */
 @Experimental
 class CountVectorizer(override val uid: String)
-  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+  extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {
 
   def this() = this(Identifiable.randomUID("cntVec"))
 
@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
   }
 
   override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizer extends Readable[CountVectorizer] {
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizer] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizer = super.load(path)
 }
 
 /**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
  */
 @Experimental
 class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
-  extends Model[CountVectorizerModel] with CountVectorizerParams {
+  extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {
+
+  import CountVectorizerModel._
 
   def this(vocabulary: Array[String]) = {
     this(Identifiable.randomUID("cntVecModel"), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
     val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
     copyValues(copied, extra)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new CountVectorizerModelWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizerModel extends Readable[CountVectorizerModel] {
+
+  private[CountVectorizerModel]
+  class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {
+
+    private case class Data(vocabulary: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.vocabulary)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {
+
+    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+
+    override def load(path: String): CountVectorizerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabulary")
+        .head()
+      val vocabulary = data.getAs[Seq[String]](0).toArray
+      val model = new CountVectorizerModel(metadata.uid, vocabulary)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizerModel = super.load(path)
 }
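On disk this follows the shared ml persistence layout: DefaultParamsWriter.saveMetadata records the params as JSON at the root of the save path, and the vocabulary goes into a single-row Parquet file under data/. A minimal usage sketch of the new API (a sketch only: the toy data, column names, and /tmp path are hypothetical, and a spark-shell-style sqlContext is assumed in scope):

import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// Toy corpus: each row carries a bag of tokens.
val df = sqlContext.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val model: CountVectorizerModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)
  .fit(df)

// Round trip through the new Writable/Readable API.
model.write.overwrite().save("/tmp/cvModel")   // overwrite() assumed from ml.util.Writer
val loaded = CountVectorizerModel.load("/tmp/cvModel")
assert(loaded.vocabulary.sameElements(model.vocabulary))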

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ object IDFModel extends Readable[IDFModel] {
 
   private class IDFModelReader extends Reader[IDFModel] {
 
-    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+    private val className = "org.apache.spark.ml.feature.IDFModel"
 
     override def load(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
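Worth noting: this one-liner fixes a copy-paste bug from the earlier persistence patch. DefaultParamsReader.loadMetadata validates the class name recorded in the saved metadata against the expected one, so with the old "StringIndexerModel" string, IDFModel.load would have rejected every model that IDFModel itself had saved.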

mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala

Lines changed: 0 additions & 2 deletions
@@ -236,5 +236,3 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
   @Since("1.6.0")
   override def load(path: String): MinMaxScalerModel = super.load(path)
 }
-
-

mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala

Lines changed: 23 additions & 1 deletion
@@ -18,14 +18,17 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.Row
 
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
 
   test("params") {
+    ParamsSuite.checkParams(new CountVectorizer)
     ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
   }
 
@@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(features ~== expected absTol 1e-14)
     }
   }
+
+  test("CountVectorizer read/write") {
+    val t = new CountVectorizer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMinDF(0.5)
+      .setMinTF(3.0)
+      .setVocabSize(10)
+    testDefaultReadWrite(t)
+  }
+
+  test("CountVectorizerModel read/write") {
+    val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c"))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMinTF(3.0)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.vocabulary === instance.vocabulary)
+  }
 }
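For intuition, roughly what testDefaultReadWrite exercises in the model case, hand-rolled (a sketch under assumptions: the real helper from DefaultReadWriteTest also checks param round-tripping and overwrite semantics against a managed temp directory, not the hypothetical /tmp path used here):

val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c"))
  .setInputCol("myInputCol")
  .setOutputCol("myOutputCol")
  .setMinTF(3.0)

val path = "/tmp/cvModelRoundTrip"   // hypothetical; the helper uses a temp dir
instance.write.overwrite().save(path)
val loaded = CountVectorizerModel.load(path)

assert(loaded.uid === instance.uid)
assert(loaded.getMinTF === instance.getMinTF)
assert(loaded.vocabulary === instance.vocabulary)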

mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -67,6 +68,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     }
   }
 
+  test("params") {
+    ParamsSuite.checkParams(new StandardScaler)
+    val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
+    ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+  }
+
   test("Standardization with default parameter") {
     val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
 