Skip to content

Commit 2cf46d5

Browse files
yinxusen authored and mengxr committed
[SPARK-11871] Add save/load for MLPC
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11871

Add save/load for MLPC (MultilayerPerceptronClassifier and its model).

## How was this patch tested?

Tested with Scala unit tests.

Author: Xusen Yin <[email protected]>

Closes apache#9854 from yinxusen/SPARK-11871.
1 parent d283223 commit 2cf46d5

File tree

2 files changed

+103
-9
lines changed

2 files changed

+103
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
1919

2020
import scala.collection.JavaConverters._
2121

22+
import org.apache.hadoop.fs.Path
23+
2224
import org.apache.spark.annotation.{Experimental, Since}
2325
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
2426
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
2527
import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
2628
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
27-
import org.apache.spark.ml.util.Identifiable
29+
import org.apache.spark.ml.util._
2830
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2931
import org.apache.spark.mllib.regression.LabeledPoint
3032
import org.apache.spark.sql.DataFrame
@@ -110,7 +112,7 @@ private object LabelConverter {
110112
class MultilayerPerceptronClassifier @Since("1.5.0") (
111113
@Since("1.5.0") override val uid: String)
112114
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
113-
with MultilayerPerceptronParams {
115+
with MultilayerPerceptronParams with DefaultParamsWritable {
114116

115117
@Since("1.5.0")
116118
def this() = this(Identifiable.randomUID("mlpc"))
@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
172174
}
173175
}
174176

177+
/**
 * Companion object of [[MultilayerPerceptronClassifier]]. Reading support comes from
 * the [[DefaultParamsReadable]] trait it mixes in.
 */
@Since("2.0.0")
object MultilayerPerceptronClassifier
  extends DefaultParamsReadable[MultilayerPerceptronClassifier] {

  /**
   * Loads a classifier from `path` by delegating to the `load` inherited from the
   * parent trait.
   *
   * @param path location the classifier was saved to
   * @return the loaded [[MultilayerPerceptronClassifier]]
   */
  @Since("2.0.0")
  override def load(path: String): MultilayerPerceptronClassifier = {
    super.load(path)
  }
}
184+
175185
/**
176186
* :: Experimental ::
177187
* Classification model based on the Multilayer Perceptron.
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
188198
@Since("1.5.0") val layers: Array[Int],
189199
@Since("1.5.0") val weights: Vector)
190200
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
191-
with Serializable {
201+
with Serializable with MLWritable {
192202

193203
@Since("1.6.0")
194204
override val numFeatures: Int = layers.head
@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
214224
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
215225
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
216226
}
227+
228+
/** Returns an [[MLWriter]] bound to this model instance. */
@Since("2.0.0")
override def write: MLWriter = {
  // The writer class lives in the companion object; bring it into scope locally.
  import MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter
  new MultilayerPerceptronClassificationModelWriter(this)
}
231+
}
232+
233+
/**
 * Companion object of [[MultilayerPerceptronClassificationModel]]. It hosts the reader
 * and writer classes used to save and load a model.
 */
@Since("2.0.0")
object MultilayerPerceptronClassificationModel
  extends MLReadable[MultilayerPerceptronClassificationModel] {

  /** Creates the [[MLReader]] used to load a saved model. */
  @Since("2.0.0")
  override def read: MLReader[MultilayerPerceptronClassificationModel] = {
    new MultilayerPerceptronClassificationModelReader
  }

  /**
   * Loads a model from `path` by delegating to the inherited `load`.
   *
   * @param path location the model was saved to
   */
  @Since("2.0.0")
  override def load(path: String): MultilayerPerceptronClassificationModel = {
    super.load(path)
  }

  /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
  private[MultilayerPerceptronClassificationModel]
  class MultilayerPerceptronClassificationModelWriter(
      instance: MultilayerPerceptronClassificationModel) extends MLWriter {

    /** Shape of the single record written to the "data" sub-directory. */
    private case class Data(layers: Array[Int], weights: Vector)

    /**
     * Writes the metadata/Params first, then the model data (layers and weights)
     * as a one-partition parquet file under `path`/data.
     */
    override protected def saveImpl(path: String): Unit = {
      // Metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)

      // Model data: layers, weights
      val modelDataDir = new Path(path, "data").toString
      val record = Data(instance.layers, instance.weights)
      val frame = sqlContext.createDataFrame(Seq(record))
      frame.repartition(1).write.parquet(modelDataDir)
    }
  }

  /** [[MLReader]] counterpart of the writer above; reads back what `saveImpl` wrote. */
  private class MultilayerPerceptronClassificationModelReader
    extends MLReader[MultilayerPerceptronClassificationModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[MultilayerPerceptronClassificationModel].getName

    /**
     * Rebuilds a model from the metadata and the "data" parquet file found under `path`,
     * then applies the saved Params to it.
     */
    override def load(path: String): MultilayerPerceptronClassificationModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val modelDataDir = new Path(path, "data").toString
      // Columns are selected explicitly so the positional reads below are well-defined:
      // index 0 is "layers", index 1 is "weights".
      val row = sqlContext.read.parquet(modelDataDir).select("layers", "weights").head()
      val restored = new MultilayerPerceptronClassificationModel(
        metadata.uid,
        row.getAs[Seq[Int]](0).toArray,
        row.getAs[Vector](1))

      DefaultParamsReader.getAndSetParams(restored, metadata)
      restored
    }
  }
}

mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,40 @@
1818
package org.apache.spark.ml.classification
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.ml.util.DefaultReadWriteTest
2122
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
2223
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
2324
import org.apache.spark.mllib.evaluation.MulticlassMetrics
2425
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2526
import org.apache.spark.mllib.util.MLlibTestSparkContext
2627
import org.apache.spark.mllib.util.TestingUtils._
27-
import org.apache.spark.sql.Row
28+
import org.apache.spark.sql.{DataFrame, Row}
2829

29-
class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
30+
class MultilayerPerceptronClassifierSuite
31+
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3032

31-
test("XOR function learning as binary classification problem with two outputs.") {
32-
val dataFrame = sqlContext.createDataFrame(Seq(
33+
@transient var dataset: DataFrame = _
34+
35+
/** Builds the four-row XOR data set stored in `dataset` after the parent set-up runs. */
override def beforeAll(): Unit = {
  super.beforeAll()

  // (features, label) pairs; each label is the XOR of the two feature values.
  val xorRows = Seq(
    (Vectors.dense(0.0, 0.0), 0.0),
    (Vectors.dense(0.0, 1.0), 1.0),
    (Vectors.dense(1.0, 0.0), 1.0),
    (Vectors.dense(1.0, 1.0), 0.0))
  dataset = sqlContext.createDataFrame(xorRows).toDF("features", "label")
}
45+
46+
test("XOR function learning as binary classification problem with two outputs.") {
3847
val layers = Array[Int](2, 5, 2)
3948
val trainer = new MultilayerPerceptronClassifier()
4049
.setLayers(layers)
4150
.setBlockSize(1)
4251
.setSeed(11L)
4352
.setMaxIter(100)
44-
val model = trainer.fit(dataFrame)
45-
val result = model.transform(dataFrame)
53+
val model = trainer.fit(dataset)
54+
val result = model.transform(dataset)
4655
val predictionAndLabels = result.select("prediction", "label").collect()
4756
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
4857
assert(p == l)
@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
92101
val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
93102
assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
94103
}
104+
105+
test("read/write: MultilayerPerceptronClassifier") {
  // Configure every Param explicitly before handing the estimator to the
  // read/write helper with Param checking enabled.
  val estimator = new MultilayerPerceptronClassifier()
  estimator.setLayers(Array(2, 3, 2))
  estimator.setMaxIter(5)
  estimator.setBlockSize(2)
  estimator.setSeed(42)
  estimator.setTol(0.1)
  estimator.setFeaturesCol("myFeatures")
  estimator.setLabelCol("myLabel")
  estimator.setPredictionCol("myPrediction")

  testDefaultReadWrite(estimator, testParams = true)
}
118+
119+
test("read/write: MultilayerPerceptronClassificationModel") {
  // Fit a small model on the suite's data set, pass it through the read/write
  // helper, and compare the model data of the returned instance with the original.
  val fitted = new MultilayerPerceptronClassifier()
    .setLayers(Array(2, 3, 2))
    .setMaxIter(5)
    .fit(dataset)
  val reloaded = testDefaultReadWrite(fitted, testParams = true)

  assert(reloaded.layers === fitted.layers)
  assert(reloaded.weights === fitted.weights)
}
95126
}

0 commit comments

Comments
 (0)