From 6c933650389798f8e3caf3e50604bceae79a126e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 6 Mar 2016 15:00:44 -0800 Subject: [PATCH 1/3] change R glm to use GLM --- R/pkg/R/mllib.R | 22 +++---- .../apache/spark/ml/r/SparkRWrappers.scala | 44 ++++++++++++- .../GeneralizedLinearRegression.scala | 64 +++++++++++++++---- 3 files changed, 104 insertions(+), 26 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 346f33d7dab2..88f46912311c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -29,15 +29,9 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training -#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param family a description of the error distribution and link function to be used in the model.. #' @param lambda Regularization parameter -#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) -#' @param standardize Whether to standardize features before training -#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and -#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory -#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an -#' analytical solution to the linear regression problem. The default value is "auto" -#' which means that the solver algorithm is selected automatically. +#' @param solver Currently only support "irls" which is also the default solver. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -51,13 +45,10 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, - standardize = TRUE, solver = "auto") { - family <- match.arg(family) + function(formula, family="gaussian", data, lambda = 0, solver = "irls") { formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", formula, data@sdf, family, lambda, - alpha, standardize, solver) + "fitGLM", formula, data@sdf, family, lambda, solver) return(new("PipelineModel", model = model)) }) @@ -124,6 +115,11 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) + } else if (modelName == "GeneralizedLinearRegressionModel") { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) } else if (modelName == "KMeansModel") { modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getKMeansModelSize", object@model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index d23e4fc9d1f5..f43e10bfe6ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,15 +17,41 @@ package org.apache.spark.ml.api.r +import org.apache.spark.SparkException import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.clustering.{KMeans, KMeansModel} import org.apache.spark.ml.feature.{RFormula, VectorAssembler} -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.ml.regression._ import org.apache.spark.sql.DataFrame private[r] object SparkRWrappers { + def fitGLM( + value: String, + df: DataFrame, + family: String, + lambda: Double, + solver: String): PipelineModel = { + if (solver.trim != "irls") throw new SparkException("Currently only support irls") + + val formula = new RFormula().setFormula(value) + val regex = "^\\s*(\\w+)\\s*(\\(\\s*link\\s*=\\s*\"(\\w+)\"\\s*\\))?\\s*$".r + val estimator = family match { + case regex(familyName, group2, linkName) => + val estimator = new GeneralizedLinearRegression() + .setFamily(familyName) + .setRegParam(lambda) + .setFitIntercept(formula.hasIntercept) + if (linkName != null) estimator.setLink(linkName) + estimator + case _ => throw new SparkException(s"Could not parse family: $family") + } + + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } + def fitRModelFormula( value: String, df: DataFrame, @@ -91,6 +117,12 @@ private[r] object SparkRWrappers { } case m: KMeansModel => m.clusterCenters.flatMap(_.toArray) + case m: GeneralizedLinearRegressionModel => + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray + } else { + m.coefficients.toArray + } } } @@ -151,6 +183,14 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) attrs.attributes.get.map(_.name.get) + case m: GeneralizedLinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } } } @@ -162,6 +202,8 @@ private[r] object SparkRWrappers { "LogisticRegressionModel" case m: KMeansModel => "KMeansModel" + case m: GeneralizedLinearRegressionModel => + "GeneralizedLinearRegressionModel" } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a850dfee0a45..4ab978ce53ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -208,26 +208,29 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val Instance(label, weight, features) } - if (familyObj == Gaussian && linkObj == Identity) { + val model = if (familyObj == Gaussian && linkObj == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) - val model = copyValues( + copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - return model + } + else { + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, + familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) } - // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) - val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, - $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) - - val model = copyValues( - new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) - .setParent(this)) + val summary = new GeneralizedLinearRegressionSummary(model.transform(dataset), + $(predictionCol), $(labelCol), $(featuresCol)) + model.setSummary(summary) model } @@ -569,9 +572,46 @@ class GeneralizedLinearRegressionModel private[ml] ( familyAndLink.fitted(eta) } + private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None + + private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: GeneralizedLinearRegressionSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this GeneralizedLinearRegressionModel", + new NullPointerException()) + } + @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) .setParent(parent) } } + +/** + * :: Experimental :: + * GeneralizedLinearRegressionModel results evaluated on a dataset. + * + * @param predictions dataframe outputted by the model's `transform` method. + * @param predictionCol field in "predictions" which gives the prediction of each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionSummary private[regression] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val labelCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable From 1cca19e68d9ef256769594e02d123ce6e3b0bd7d Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 6 Mar 2016 15:27:58 -0800 Subject: [PATCH 2/3] refine family --- R/pkg/R/mllib.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 88f46912311c..b4ce2aa9866c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -45,7 +45,9 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family="gaussian", data, lambda = 0, solver = "irls") { + function(formula, family = c("gaussian", "binomial", "poisson", "gamma"), data, + lambda = 0, solver = "irls") { + family <- match.arg(family) formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitGLM", formula, data@sdf, family, lambda, solver) From 8b3dd3eebd808bc2c8e139f55e20de9d54b5cda2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 17 Mar 2016 01:58:30 -0700 Subject: [PATCH 3/3] extract link and family name in R --- R/pkg/R/mllib.R | 16 ++++----- .../apache/spark/ml/r/SparkRWrappers.scala | 35 ++++--------------- 2 files changed, 12 insertions(+), 39 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index a1f763ccb90d..a50ff5f12563 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -29,7 +29,8 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training -#' @param family a description of the error distribution and link function to be used in the model.. +#' @param family a description of the error distribution and link function to be used in the model, +#' as in [[https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html]] #' @param lambda Regularization parameter #' @param solver Currently only support "irls" which is also the default solver. #' @return a fitted MLlib model @@ -45,12 +46,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial", "poisson", "gamma"), data, - lambda = 0, solver = "irls") { - family <- match.arg(family) + function(formula, family = gaussian(), data, lambda = 0, solver = "auto") { + familyName <- family$family + linkName <- family$link formula <- paste(deparse(formula), collapse = "") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitGLM", formula, data@sdf, family, lambda, solver) + "fitGLM", formula, data@sdf, familyName, linkName, lambda, solver) return(new("PipelineModel", model = model)) }) @@ -117,11 +118,6 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) - } else if (modelName == "GeneralizedLinearRegressionModel") { - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) } else if (modelName == "KMeansModel") { modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getKMeansModelSize", object@model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index f43e10bfe6ad..84cb02d19b7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -31,23 +31,16 @@ private[r] object SparkRWrappers { value: String, df: DataFrame, family: String, + link: String, lambda: Double, solver: String): PipelineModel = { - if (solver.trim != "irls") throw new SparkException("Currently only support irls") - val formula = new RFormula().setFormula(value) - val regex = "^\\s*(\\w+)\\s*(\\(\\s*link\\s*=\\s*\"(\\w+)\"\\s*\\))?\\s*$".r - val estimator = family match { - case regex(familyName, group2, linkName) => - val estimator = new GeneralizedLinearRegression() - .setFamily(familyName) - .setRegParam(lambda) - .setFitIntercept(formula.hasIntercept) - if (linkName != null) estimator.setLink(linkName) - estimator - case _ => throw new SparkException(s"Could not parse family: $family") - } + val estimator = new GeneralizedLinearRegression() + .setFamily(family) + .setRegParam(lambda) + .setFitIntercept(formula.hasIntercept) + if (link != null) estimator.setLink(link) val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } @@ -117,12 +110,6 @@ private[r] object SparkRWrappers { } case m: KMeansModel => m.clusterCenters.flatMap(_.toArray) - case m: GeneralizedLinearRegressionModel => - if (m.getFitIntercept) { - Array(m.intercept) ++ m.coefficients.toArray - } else { - m.coefficients.toArray - } } } @@ -183,14 +170,6 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) attrs.attributes.get.map(_.name.get) - case m: GeneralizedLinearRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - if (m.getFitIntercept) { - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } else { - attrs.attributes.get.map(_.name.get) - } } } @@ -202,8 +181,6 @@ private[r] object SparkRWrappers { "LogisticRegressionModel" case m: KMeansModel => "KMeansModel" - case m: GeneralizedLinearRegressionModel => - "GeneralizedLinearRegressionModel" } } }