@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
1919
2020import scala .collection .JavaConverters ._
2121
22+ import org .apache .hadoop .fs .Path
23+
2224import org .apache .spark .annotation .{Experimental , Since }
2325import org .apache .spark .ml .{PredictionModel , Predictor , PredictorParams }
2426import org .apache .spark .ml .ann .{FeedForwardTopology , FeedForwardTrainer }
2527import org .apache .spark .ml .param .{IntArrayParam , IntParam , ParamMap , ParamValidators }
2628import org .apache .spark .ml .param .shared .{HasMaxIter , HasSeed , HasTol }
27- import org .apache .spark .ml .util .Identifiable
29+ import org .apache .spark .ml .util ._
2830import org .apache .spark .mllib .linalg .{Vector , Vectors }
2931import org .apache .spark .mllib .regression .LabeledPoint
3032import org .apache .spark .sql .DataFrame
@@ -110,7 +112,7 @@ private object LabelConverter {
110112class MultilayerPerceptronClassifier @ Since (" 1.5.0" ) (
111113 @ Since (" 1.5.0" ) override val uid : String )
112114 extends Predictor [Vector , MultilayerPerceptronClassifier , MultilayerPerceptronClassificationModel ]
113- with MultilayerPerceptronParams {
115+ with MultilayerPerceptronParams with DefaultParamsWritable {
114116
115117 @ Since (" 1.5.0" )
116118 def this () = this (Identifiable .randomUID(" mlpc" ))
@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
172174 }
173175}
174176
177+ @ Since (" 2.0.0" )
178+ object MultilayerPerceptronClassifier
179+ extends DefaultParamsReadable [MultilayerPerceptronClassifier ] {
180+
181+ @ Since (" 2.0.0" )
182+ override def load (path : String ): MultilayerPerceptronClassifier = super .load(path)
183+ }
184+
175185/**
176186 * :: Experimental ::
177187 * Classification model based on the Multilayer Perceptron.
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
188198 @ Since (" 1.5.0" ) val layers : Array [Int ],
189199 @ Since (" 1.5.0" ) val weights : Vector )
190200 extends PredictionModel [Vector , MultilayerPerceptronClassificationModel ]
191- with Serializable {
201+ with Serializable with MLWritable {
192202
193203 @ Since (" 1.6.0" )
194204 override val numFeatures : Int = layers.head
@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
214224 override def copy (extra : ParamMap ): MultilayerPerceptronClassificationModel = {
215225 copyValues(new MultilayerPerceptronClassificationModel (uid, layers, weights), extra)
216226 }
227+
228+ @ Since (" 2.0.0" )
229+ override def write : MLWriter =
230+ new MultilayerPerceptronClassificationModel .MultilayerPerceptronClassificationModelWriter (this )
231+ }
232+
233+ @ Since (" 2.0.0" )
234+ object MultilayerPerceptronClassificationModel
235+ extends MLReadable [MultilayerPerceptronClassificationModel ] {
236+
237+ @ Since (" 2.0.0" )
238+ override def read : MLReader [MultilayerPerceptronClassificationModel ] =
239+ new MultilayerPerceptronClassificationModelReader
240+
241+ @ Since (" 2.0.0" )
242+ override def load (path : String ): MultilayerPerceptronClassificationModel = super .load(path)
243+
244+ /** [[MLWriter ]] instance for [[MultilayerPerceptronClassificationModel ]] */
245+ private [MultilayerPerceptronClassificationModel ]
246+ class MultilayerPerceptronClassificationModelWriter (
247+ instance : MultilayerPerceptronClassificationModel ) extends MLWriter {
248+
249+ private case class Data (layers : Array [Int ], weights : Vector )
250+
251+ override protected def saveImpl (path : String ): Unit = {
252+ // Save metadata and Params
253+ DefaultParamsWriter .saveMetadata(instance, path, sc)
254+ // Save model data: layers, weights
255+ val data = Data (instance.layers, instance.weights)
256+ val dataPath = new Path (path, " data" ).toString
257+ sqlContext.createDataFrame(Seq (data)).repartition(1 ).write.parquet(dataPath)
258+ }
259+ }
260+
261+ private class MultilayerPerceptronClassificationModelReader
262+ extends MLReader [MultilayerPerceptronClassificationModel ] {
263+
264+ /** Checked against metadata when loading model */
265+ private val className = classOf [MultilayerPerceptronClassificationModel ].getName
266+
267+ override def load (path : String ): MultilayerPerceptronClassificationModel = {
268+ val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
269+
270+ val dataPath = new Path (path, " data" ).toString
271+ val data = sqlContext.read.parquet(dataPath).select(" layers" , " weights" ).head()
272+ val layers = data.getAs[Seq [Int ]](0 ).toArray
273+ val weights = data.getAs[Vector ](1 )
274+ val model = new MultilayerPerceptronClassificationModel (metadata.uid, layers, weights)
275+
276+ DefaultParamsReader .getAndSetParams(model, metadata)
277+ model
278+ }
279+ }
217280}
0 commit comments