1616 */
1717package org .apache .spark .ml .feature
1818
19- import org .apache .spark .annotation .Experimental
19+ import org .apache .hadoop .fs .Path
20+
21+ import org .apache .spark .annotation .{Experimental , Since }
2022import org .apache .spark .broadcast .Broadcast
23+ import org .apache .spark .ml .{Estimator , Model }
2124import org .apache .spark .ml .param ._
2225import org .apache .spark .ml .param .shared .{HasInputCol , HasOutputCol }
23- import org .apache .spark .ml .util .{Identifiable , SchemaUtils }
24- import org .apache .spark .ml .{Estimator , Model }
26+ import org .apache .spark .ml .util ._
2527import org .apache .spark .mllib .linalg .{VectorUDT , Vectors }
2628import org .apache .spark .rdd .RDD
29+ import org .apache .spark .sql .DataFrame
2730import org .apache .spark .sql .functions ._
2831import org .apache .spark .sql .types ._
29- import org .apache .spark .sql .DataFrame
3032import org .apache .spark .util .collection .OpenHashMap
3133
3234/**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
105107 */
106108@ Experimental
107109class CountVectorizer (override val uid : String )
108- extends Estimator [CountVectorizerModel ] with CountVectorizerParams {
110+ extends Estimator [CountVectorizerModel ] with CountVectorizerParams with Writable {
109111
/** Creates an instance whose uid is obtained from `Identifiable.randomUID("cntVec")`. */
def this() = this(Identifiable.randomUID("cntVec"))
111113
@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
169171 }
170172
/** Returns the result of `defaultCopy` applied to the given extra param map. */
override def copy(extra: ParamMap): CountVectorizer = {
  defaultCopy(extra)
}
174+
/** Returns a [[DefaultParamsWriter]] wrapping this estimator. */
@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
177+ }
178+
/**
 * Companion object of [[CountVectorizer]], exposing the reader and `load` entry points
 * declared by [[Readable]].
 */
@Since("1.6.0")
object CountVectorizer extends Readable[CountVectorizer] {

  /** Returns a [[DefaultParamsReader]] typed to [[CountVectorizer]]. */
  @Since("1.6.0")
  override def read: Reader[CountVectorizer] = new DefaultParamsReader[CountVectorizer]

  /**
   * Loads a [[CountVectorizer]] from `path` by delegating to the inherited `load`.
   *
   * @param path location passed through unchanged to `super.load`
   */
  @Since("1.6.0")
  override def load(path: String): CountVectorizer = super.load(path)
}
173188
174189/**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
178193 */
179194@ Experimental
180195class CountVectorizerModel (override val uid : String , val vocabulary : Array [String ])
181- extends Model [CountVectorizerModel ] with CountVectorizerParams {
196+ extends Model [CountVectorizerModel ] with CountVectorizerParams with Writable {
197+
198+ import CountVectorizerModel ._
182199
183200 def this (vocabulary : Array [String ]) = {
184201 this (Identifiable .randomUID(" cntVecModel" ), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
232249 val copied = new CountVectorizerModel (uid, vocabulary).setParent(parent)
233250 copyValues(copied, extra)
234251 }
252+
/** Returns a `CountVectorizerModelWriter` wrapping this model. */
@Since("1.6.0")
override def write: Writer = new CountVectorizerModelWriter(this)
255+ }
256+
/**
 * Companion object of [[CountVectorizerModel]] holding the writer and reader used to
 * save and load a model: params go through the `DefaultParams*` helpers and the
 * vocabulary is stored as a one-row Parquet dataset under `<path>/data`.
 */
@Since("1.6.0")
object CountVectorizerModel extends Readable[CountVectorizerModel] {

  /**
   * Writer that saves the instance's metadata followed by its vocabulary.
   *
   * @param instance the model whose metadata and vocabulary are written
   */
  private[CountVectorizerModel]
  class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {

    /** Row shape of the Parquet data; the field name doubles as the column name. */
    private case class Data(vocabulary: Seq[String])

    /**
     * Saves metadata via `DefaultParamsWriter.saveMetadata`, then writes the vocabulary
     * as a single-partition Parquet dataset to the `data` child of `path`.
     */
    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.vocabulary)
      val dataPath = new Path(path, "data").toString
      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /** Reader that rebuilds a [[CountVectorizerModel]] from what the writer above stored. */
  private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {

    // Derived from the class itself so the name handed to `loadMetadata` cannot drift
    // from the real fully-qualified class name.
    private val className = classOf[CountVectorizerModel].getName

    /**
     * Loads metadata, reads the `vocabulary` column of the first row under
     * `<path>/data`, builds the model with the stored uid, and applies the stored params.
     *
     * @param path root location the model was saved to
     * @return the reconstructed model
     */
    override def load(path: String): CountVectorizerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val row = sqlContext.read.parquet(dataPath)
        .select("vocabulary")
        .head()
      val vocabulary = row.getAs[Seq[String]](0).toArray
      val model = new CountVectorizerModel(metadata.uid, vocabulary)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  /** Returns a new `CountVectorizerModelReader`. */
  @Since("1.6.0")
  override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader

  /**
   * Loads a [[CountVectorizerModel]] from `path` by delegating to the inherited `load`.
   *
   * @param path location passed through unchanged to `super.load`
   */
  @Since("1.6.0")
  override def load(path: String): CountVectorizerModel = super.load(path)
}
0 commit comments