Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,10 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param family a description of the error distribution and link function to be used in the model,
#' as in [[https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html]]
#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
#' @param standardize Whether to standardize features before training
#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and
#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory
#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an
#' analytical solution to the linear regression problem. The default value is "auto"
#' which means that the solver algorithm is selected automatically.
#' @param solver Currently only support "irls" which is also the default solver.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous comment was more explicit, especially with respect to 'auto' (the default). It should mention auto and irls as the two options.

#' @return a fitted MLlib model
#' @rdname glm
#' @export
Expand All @@ -51,13 +46,12 @@ setClass("PipelineModel", representation(model = "jobj"))
#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
standardize = TRUE, solver = "auto") {
family <- match.arg(family)
function(formula, family = gaussian(), data, lambda = 0, solver = "auto") {
familyName <- family$family
linkName <- family$link
formula <- paste(deparse(formula), collapse = "")
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", formula, data@sdf, family, lambda,
alpha, standardize, solver)
"fitGLM", formula, data@sdf, familyName, linkName, lambda, solver)
return(new("PipelineModel", model = model))
})

Expand Down
21 changes: 20 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,34 @@

package org.apache.spark.ml.api.r

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.regression._
import org.apache.spark.sql.DataFrame

private[r] object SparkRWrappers {
def fitGLM(
value: String,
df: DataFrame,
family: String,
link: String,
lambda: Double,
solver: String): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = new GeneralizedLinearRegression()
.setFamily(family)
.setRegParam(lambda)
.setFitIntercept(formula.hasIntercept)

if (link != null) estimator.setLink(link)
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}

def fitRModelFormula(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that method is not used anymore, right? We should remove it.

value: String,
df: DataFrame,
Expand Down