
Commit 7e987de

[SPARK-6787][ML] add read/write to estimators under ml.feature (1)

Add read/write support to the following estimators under spark.ml:

* CountVectorizer
* IDF
* MinMaxScaler
* StandardScaler (a little awkward because we store some params in the spark.mllib model)
* StringIndexer

Added some necessary methods for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #9798 from mengxr/SPARK-6787.
1 parent 5df0894 commit 7e987de
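
Taken together, the new Writable/Readable plumbing gives each of these estimators and models a uniform save/load story. Below is a minimal sketch of how it might be used from Scala, assuming an existing sc/sqlContext and that the Writer returned by write exposes save(path) as in the existing ml.util read/write framework; the column names, toy data, and /tmp paths are made up for illustration.

import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// Toy corpus: one column of tokenized documents (names and data are illustrative).
val df = sqlContext.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val cv = new CountVectorizer().setInputCol("words").setOutputCol("features")
val cvModel = cv.fit(df)

// Estimator: params only, via DefaultParamsWriter. Model: params plus vocabulary.
cv.write.save("/tmp/spark-cv")
cvModel.write.save("/tmp/spark-cv-model")

// Load back through the companion objects added in this commit.
val cvLoaded = CountVectorizer.load("/tmp/spark-cv")
val cvModelLoaded = CountVectorizerModel.load("/tmp/spark-cv-model")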

File tree

10 files changed: +467 lines, -47 lines


mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 66 additions & 6 deletions
@@ -16,17 +16,19 @@
  */
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
  */
 @Experimental
 class CountVectorizer(override val uid: String)
-  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+  extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {
 
   def this() = this(Identifiable.randomUID("cntVec"))
 
@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
   }
 
   override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizer extends Readable[CountVectorizer] {
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizer] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizer = super.load(path)
 }
 
 /**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
  */
 @Experimental
 class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
-  extends Model[CountVectorizerModel] with CountVectorizerParams {
+  extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {
+
+  import CountVectorizerModel._
 
   def this(vocabulary: Array[String]) = {
     this(Identifiable.randomUID("cntVecModel"), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
     val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
     copyValues(copied, extra)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new CountVectorizerModelWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizerModel extends Readable[CountVectorizerModel] {
+
+  private[CountVectorizerModel]
+  class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {
+
+    private case class Data(vocabulary: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.vocabulary)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {
+
+    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+
+    override def load(path: String): CountVectorizerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabulary")
+        .head()
+      val vocabulary = data.getAs[Seq[String]](0).toArray
+      val model = new CountVectorizerModel(metadata.uid, vocabulary)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizerModel = super.load(path)
 }
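
Under the hood, model writers split persistence in two: DefaultParamsWriter.saveMetadata records the uid and params, and the model-specific state (here the vocabulary) goes to a one-partition Parquet file under data/, which is exactly what the reader above loads back. A hedged sketch for inspecting the artifacts written by the earlier example, assuming saveMetadata places a JSON text file under metadata/ next to data/ (that detail is part of DefaultParamsWriter, not shown in this diff):

// Assumes the /tmp/spark-cv-model directory written in the earlier example.
val metaJson = sc.textFile("/tmp/spark-cv-model/metadata").first() // one JSON line: class, uid, params, Spark version
val vocabRow = sqlContext.read.parquet("/tmp/spark-cv-model/data")
  .select("vocabulary") // same column the CountVectorizerModelReader selects
  .head()
println(metaJson)
println(vocabRow.getAs[Seq[String]](0).take(5).mkString(", "))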

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

Lines changed: 67 additions & 4 deletions
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
@@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
  * Compute the Inverse Document Frequency (IDF) given a collection of documents.
  */
 @Experimental
-final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {
 
   def this() = this(Identifiable.randomUID("idf"))
 
@@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
   }
 
   override def copy(extra: ParamMap): IDF = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object IDF extends Readable[IDF] {
+
+  @Since("1.6.0")
+  override def read: Reader[IDF] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): IDF = super.load(path)
 }
 
 /**
@@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
 class IDFModel private[ml] (
     override val uid: String,
     idfModel: feature.IDFModel)
-  extends Model[IDFModel] with IDFBase {
+  extends Model[IDFModel] with IDFBase with Writable {
+
+  import IDFModel._
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -117,4 +134,50 @@ class IDFModel private[ml] (
     val copied = new IDFModel(uid, idfModel)
     copyValues(copied, extra).setParent(parent)
   }
+
+  /** Returns the IDF vector. */
+  @Since("1.6.0")
+  def idf: Vector = idfModel.idf
+
+  @Since("1.6.0")
+  override def write: Writer = new IDFModelWriter(this)
+}
+
+@Since("1.6.0")
+object IDFModel extends Readable[IDFModel] {
+
+  private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {
+
+    private case class Data(idf: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.idf)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class IDFModelReader extends Reader[IDFModel] {
+
+    private val className = "org.apache.spark.ml.feature.IDFModel"
+
+    override def load(path: String): IDFModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("idf")
+        .head()
+      val idf = data.getAs[Vector](0)
+      val model = new IDFModel(metadata.uid, new feature.IDFModel(idf))
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[IDFModel] = new IDFModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): IDFModel = super.load(path)
 }
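
Because IDFModel now exposes the underlying idf vector, a save/load round trip is easy to sanity-check. A sketch under the same assumptions as the earlier example; tfDf is a hypothetical DataFrame with a Vector column "tf" of term frequencies, and the path is made up:

import org.apache.spark.ml.feature.{IDF, IDFModel}

val idfModel: IDFModel = new IDF().setInputCol("tf").setOutputCol("tfidf").fit(tfDf)

idfModel.write.save("/tmp/spark-idf-model")
val reloaded = IDFModel.load("/tmp/spark-idf-model")

// The reloaded model should carry the same IDF weights and params.
assert(reloaded.idf == idfModel.idf)
assert(reloaded.getOutputCol == idfModel.getOutputCol)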

mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala

Lines changed: 66 additions & 6 deletions
@@ -17,11 +17,14 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
-import org.apache.spark.ml.util.Identifiable
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.sql._
@@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
  */
 @Experimental
 class MinMaxScaler(override val uid: String)
-  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {
 
   def this() = this(Identifiable.randomUID("minMaxScal"))
 
@@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String)
   }
 
   override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScaler extends Readable[MinMaxScaler] {
+
+  @Since("1.6.0")
+  override def read: Reader[MinMaxScaler] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): MinMaxScaler = super.load(path)
 }
 
 /**
@@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] (
     override val uid: String,
     val originalMin: Vector,
     val originalMax: Vector)
-  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+
+  import MinMaxScalerModel._
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] (
     val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new MinMaxScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
+
+  private[MinMaxScalerModel]
+  class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {
+
+    private case class Data(originalMin: Vector, originalMax: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = new Data(instance.originalMin, instance.originalMax)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {
+
+    private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+
+    override def load(path: String): MinMaxScalerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+        .select("originalMin", "originalMax")
+        .head()
+      val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
  }
+
+  @Since("1.6.0")
+  override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): MinMaxScalerModel = super.load(path)
 }
