diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 5c0d3dcf3af9..a50ff5f12563 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -29,15 +29,10 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and '-'.
 #' @param data DataFrame for training
-#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
+#' @param family a description of the error distribution and link function to be used in the model,
+#'               as in [[https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html]]
 #' @param lambda Regularization parameter
-#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
-#' @param standardize Whether to standardize features before training
-#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and
-#'               "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory
-#'               quasi-Newton optimization method. "normal" denotes using Normal Equation as an
-#'               analytical solution to the linear regression problem. The default value is "auto"
-#'               which means that the solver algorithm is selected automatically.
+#' @param solver Currently only support "irls" which is also the default solver.
 #' @return a fitted MLlib model
 #' @rdname glm
 #' @export
@@ -51,13 +46,12 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' summary(model)
 #'}
 setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
-          function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
-                   standardize = TRUE, solver = "auto") {
-            family <- match.arg(family)
+          function(formula, family = gaussian(), data, lambda = 0, solver = "auto") {
+            familyName <- family$family
+            linkName <- family$link
             formula <- paste(deparse(formula), collapse = "")
             model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                 "fitRModelFormula", formula, data@sdf, family, lambda,
-                                 alpha, standardize, solver)
+                                 "fitGLM", formula, data@sdf, familyName, linkName, lambda, solver)
             return(new("PipelineModel", model = model))
           })

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index d23e4fc9d1f5..84cb02d19b7a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,15 +17,34 @@

 package org.apache.spark.ml.api.r

+import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import org.apache.spark.ml.regression._
 import org.apache.spark.sql.DataFrame

 private[r] object SparkRWrappers {
+  def fitGLM(
+      value: String,
+      df: DataFrame,
+      family: String,
+      link: String,
+      lambda: Double,
+      solver: String): PipelineModel = {
+    val formula = new RFormula().setFormula(value)
+    val estimator = new GeneralizedLinearRegression()
+      .setFamily(family)
+      .setRegParam(lambda)
+      .setFitIntercept(formula.hasIntercept)
+
+    if (link != null) estimator.setLink(link)
+    val pipeline = new Pipeline().setStages(Array(formula, estimator))
+    pipeline.fit(df)
+  }
+
   def fitRModelFormula(
       value: String,
       df: DataFrame,