1818package org .apache .spark .ml .regression
1919
2020import breeze .stats .distributions .{Gaussian => GD }
21+ import org .apache .hadoop .fs .Path
2122
2223import org .apache .spark .{Logging , SparkException }
2324import org .apache .spark .annotation .{Experimental , Since }
@@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
2627import org .apache .spark .ml .optim ._
2728import org .apache .spark .ml .param ._
2829import org .apache .spark .ml .param .shared ._
29- import org .apache .spark .ml .util .Identifiable
30+ import org .apache .spark .ml .util ._
3031import org .apache .spark .mllib .linalg .{BLAS , Vector }
3132import org .apache .spark .rdd .RDD
3233import org .apache .spark .sql .{DataFrame , Row }
@@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
106107@ Since (" 2.0.0" )
107108class GeneralizedLinearRegression @ Since (" 2.0.0" ) (@ Since (" 2.0.0" ) override val uid : String )
108109 extends Regressor [Vector , GeneralizedLinearRegression , GeneralizedLinearRegressionModel ]
109- with GeneralizedLinearRegressionBase with Logging {
110+ with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
110111
111112 import GeneralizedLinearRegression ._
112113
@@ -236,23 +237,26 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
236237}
237238
238239@ Since (" 2.0.0" )
239- private [ml] object GeneralizedLinearRegression {
240+ object GeneralizedLinearRegression extends DefaultParamsReadable [GeneralizedLinearRegression ] {
241+
242+ @ Since (" 2.0.0" )
243+ override def load (path : String ): GeneralizedLinearRegression = super .load(path)
240244
241245 /** Set of family and link pairs that GeneralizedLinearRegression supports. */
242- lazy val supportedFamilyAndLinkPairs = Set (
246+ private [ml] lazy val supportedFamilyAndLinkPairs = Set (
243247 Gaussian -> Identity , Gaussian -> Log , Gaussian -> Inverse ,
244248 Binomial -> Logit , Binomial -> Probit , Binomial -> CLogLog ,
245249 Poisson -> Log , Poisson -> Identity , Poisson -> Sqrt ,
246250 Gamma -> Inverse , Gamma -> Identity , Gamma -> Log
247251 )
248252
249253 /** Set of family names that GeneralizedLinearRegression supports. */
250- lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
254+ private [ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
251255
252256 /** Set of link names that GeneralizedLinearRegression supports. */
253- lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
257+ private [ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
254258
255- val epsilon : Double = 1E-16
259+ private [ml] val epsilon : Double = 1E-16
256260
257261 /**
258262 * Wrapper of family and link combination used in the model.
@@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
552556 @ Since (" 2.0.0" ) val coefficients : Vector ,
553557 @ Since (" 2.0.0" ) val intercept : Double )
554558 extends RegressionModel [Vector , GeneralizedLinearRegressionModel ]
555- with GeneralizedLinearRegressionBase {
559+ with GeneralizedLinearRegressionBase with MLWritable {
556560
557561 import GeneralizedLinearRegression ._
558562
@@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
574578 copyValues(new GeneralizedLinearRegressionModel (uid, coefficients, intercept), extra)
575579 .setParent(parent)
576580 }
581+
582+ @ Since (" 2.0.0" )
583+ override def write : MLWriter =
584+ new GeneralizedLinearRegressionModel .GeneralizedLinearRegressionModelWriter (this )
585+ }
586+
587+ @ Since (" 2.0.0" )
588+ object GeneralizedLinearRegressionModel extends MLReadable [GeneralizedLinearRegressionModel ] {
589+
590+ @ Since (" 2.0.0" )
591+ override def read : MLReader [GeneralizedLinearRegressionModel ] =
592+ new GeneralizedLinearRegressionModelReader
593+
594+ @ Since (" 2.0.0" )
595+ override def load (path : String ): GeneralizedLinearRegressionModel = super .load(path)
596+
597+ /** [[MLWriter ]] instance for [[GeneralizedLinearRegressionModel ]] */
598+ private [GeneralizedLinearRegressionModel ]
599+ class GeneralizedLinearRegressionModelWriter (instance : GeneralizedLinearRegressionModel )
600+ extends MLWriter with Logging {
601+
602+ private case class Data (intercept : Double , coefficients : Vector )
603+
604+ override protected def saveImpl (path : String ): Unit = {
605+ // Save metadata and Params
606+ DefaultParamsWriter .saveMetadata(instance, path, sc)
607+ // Save model data: intercept, coefficients
608+ val data = Data (instance.intercept, instance.coefficients)
609+ val dataPath = new Path (path, " data" ).toString
610+ sqlContext.createDataFrame(Seq (data)).repartition(1 ).write.parquet(dataPath)
611+ }
612+ }
613+
614+ private class GeneralizedLinearRegressionModelReader
615+ extends MLReader [GeneralizedLinearRegressionModel ] {
616+
617+ /** Checked against metadata when loading model */
618+ private val className = classOf [GeneralizedLinearRegressionModel ].getName
619+
620+ override def load (path : String ): GeneralizedLinearRegressionModel = {
621+ val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
622+
623+ val dataPath = new Path (path, " data" ).toString
624+ val data = sqlContext.read.parquet(dataPath)
625+ .select(" intercept" , " coefficients" ).head()
626+ val intercept = data.getDouble(0 )
627+ val coefficients = data.getAs[Vector ](1 )
628+
629+ val model = new GeneralizedLinearRegressionModel (metadata.uid, coefficients, intercept)
630+
631+ DefaultParamsReader .getAndSetParams(model, metadata)
632+ model
633+ }
634+ }
577635}
0 commit comments