Skip to content

Commit 0dd0648

Browse files
yanboliangjkbradley
authored andcommitted
[SPARK-13615][ML] GeneralizedLinearRegression supports save/load
## What changes were proposed in this pull request? ```GeneralizedLinearRegression``` supports ```save/load```. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang <[email protected]> Closes #11465 from yanboliang/spark-13615.
1 parent cad29a4 commit 0dd0648

File tree

2 files changed

+96
-10
lines changed

2 files changed

+96
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.ml.regression
1919

2020
import breeze.stats.distributions.{Gaussian => GD}
21+
import org.apache.hadoop.fs.Path
2122

2223
import org.apache.spark.{Logging, SparkException}
2324
import org.apache.spark.annotation.{Experimental, Since}
@@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
2627
import org.apache.spark.ml.optim._
2728
import org.apache.spark.ml.param._
2829
import org.apache.spark.ml.param.shared._
29-
import org.apache.spark.ml.util.Identifiable
30+
import org.apache.spark.ml.util._
3031
import org.apache.spark.mllib.linalg.{BLAS, Vector}
3132
import org.apache.spark.rdd.RDD
3233
import org.apache.spark.sql.{DataFrame, Row}
@@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
106107
@Since("2.0.0")
107108
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
108109
extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
109-
with GeneralizedLinearRegressionBase with Logging {
110+
with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
110111

111112
import GeneralizedLinearRegression._
112113

@@ -236,23 +237,26 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
236237
}
237238

238239
@Since("2.0.0")
239-
private[ml] object GeneralizedLinearRegression {
240+
object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] {
241+
242+
@Since("2.0.0")
243+
override def load(path: String): GeneralizedLinearRegression = super.load(path)
240244

241245
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
242-
lazy val supportedFamilyAndLinkPairs = Set(
246+
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
243247
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
244248
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
245249
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
246250
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
247251
)
248252

249253
/** Set of family names that GeneralizedLinearRegression supports. */
250-
lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
254+
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
251255

252256
/** Set of link names that GeneralizedLinearRegression supports. */
253-
lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
257+
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
254258

255-
val epsilon: Double = 1E-16
259+
private[ml] val epsilon: Double = 1E-16
256260

257261
/**
258262
* Wrapper of family and link combination used in the model.
@@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
552556
@Since("2.0.0") val coefficients: Vector,
553557
@Since("2.0.0") val intercept: Double)
554558
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
555-
with GeneralizedLinearRegressionBase {
559+
with GeneralizedLinearRegressionBase with MLWritable {
556560

557561
import GeneralizedLinearRegression._
558562

@@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
574578
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
575579
.setParent(parent)
576580
}
581+
582+
@Since("2.0.0")
583+
override def write: MLWriter =
584+
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
585+
}
586+
587+
@Since("2.0.0")
588+
object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] {
589+
590+
@Since("2.0.0")
591+
override def read: MLReader[GeneralizedLinearRegressionModel] =
592+
new GeneralizedLinearRegressionModelReader
593+
594+
@Since("2.0.0")
595+
override def load(path: String): GeneralizedLinearRegressionModel = super.load(path)
596+
597+
/** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */
598+
private[GeneralizedLinearRegressionModel]
599+
class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel)
600+
extends MLWriter with Logging {
601+
602+
private case class Data(intercept: Double, coefficients: Vector)
603+
604+
override protected def saveImpl(path: String): Unit = {
605+
// Save metadata and Params
606+
DefaultParamsWriter.saveMetadata(instance, path, sc)
607+
// Save model data: intercept, coefficients
608+
val data = Data(instance.intercept, instance.coefficients)
609+
val dataPath = new Path(path, "data").toString
610+
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
611+
}
612+
}
613+
614+
private class GeneralizedLinearRegressionModelReader
615+
extends MLReader[GeneralizedLinearRegressionModel] {
616+
617+
/** Checked against metadata when loading model */
618+
private val className = classOf[GeneralizedLinearRegressionModel].getName
619+
620+
override def load(path: String): GeneralizedLinearRegressionModel = {
621+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
622+
623+
val dataPath = new Path(path, "data").toString
624+
val data = sqlContext.read.parquet(dataPath)
625+
.select("intercept", "coefficients").head()
626+
val intercept = data.getDouble(0)
627+
val coefficients = data.getAs[Vector](1)
628+
629+
val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
630+
631+
DefaultParamsReader.getAndSetParams(model, metadata)
632+
model
633+
}
634+
}
577635
}

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.util.Random
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.ml.param.ParamsSuite
24-
import org.apache.spark.ml.util.MLTestingUtils
24+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
2525
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
2626
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
2727
import org.apache.spark.mllib.random._
@@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
3030
import org.apache.spark.mllib.util.TestingUtils._
3131
import org.apache.spark.sql.{DataFrame, Row}
3232

33-
class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
33+
class GeneralizedLinearRegressionSuite
34+
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3435

3536
private val seed: Int = 42
3637
@transient var datasetGaussianIdentity: DataFrame = _
@@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
464465
}
465466
}
466467
}
468+
469+
test("read/write") {
470+
def checkModelData(
471+
model: GeneralizedLinearRegressionModel,
472+
model2: GeneralizedLinearRegressionModel): Unit = {
473+
assert(model.intercept === model2.intercept)
474+
assert(model.coefficients.toArray === model2.coefficients.toArray)
475+
}
476+
477+
val glr = new GeneralizedLinearRegression()
478+
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
479+
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
480+
}
467481
}
468482

469483
object GeneralizedLinearRegressionSuite {
470484

485+
/**
486+
* Mapping from all Params to valid settings which differ from the defaults.
487+
* This is useful for tests which need to exercise all Params, such as save/load.
488+
* This excludes input columns to simplify some tests.
489+
*/
490+
val allParamSettings: Map[String, Any] = Map(
491+
"family" -> "poisson",
492+
"link" -> "log",
493+
"fitIntercept" -> true,
494+
"maxIter" -> 2, // intentionally small
495+
"tol" -> 0.8,
496+
"regParam" -> 0.01,
497+
"predictionCol" -> "myPrediction")
498+
471499
def generateGeneralizedLinearRegressionInput(
472500
intercept: Double,
473501
coefficients: Array[Double],

0 commit comments

Comments
 (0)