Skip to content

Commit d4a9122

Browse files
yanboliang authored and shivaram committed
[SPARK-16710][SPARKR][ML] spark.glm should support weightCol
## What changes were proposed in this pull request?
Training GLMs on weighted datasets is a very important use case, but it is not currently supported by SparkR. In native R, users can pass the `weights` argument to specify the weights vector. For `spark.glm`, we can pass in `weightCol`, which is consistent with MLlib.
## How was this patch tested?
Unit test.
Author: Yanbo Liang <[email protected]>
Closes #14346 from yanboliang/spark-16710.
1 parent 19af298 commit d4a9122

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

R/pkg/R/mllib.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ NULL
9191
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
9292
#' @param tol Positive convergence tolerance of iterations.
9393
#' @param maxIter Integer giving the maximal number of IRLS iterations.
94+
#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
95+
#' weights as 1.0.
9496
#' @aliases spark.glm,SparkDataFrame,formula-method
9597
#' @return \code{spark.glm} returns a fitted generalized linear model
9698
#' @rdname spark.glm
@@ -119,7 +121,7 @@ NULL
119121
#' @note spark.glm since 2.0.0
120122
#' @seealso \link{glm}, \link{read.ml}
121123
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
122-
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) {
124+
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
123125
if (is.character(family)) {
124126
family <- get(family, mode = "function", envir = parent.frame())
125127
}
@@ -132,10 +134,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
132134
}
133135

134136
formula <- paste(deparse(formula), collapse = "")
137+
if (is.null(weightCol)) {
138+
weightCol <- ""
139+
}
135140

136141
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
137142
"fit", formula, data@sdf, family$family, family$link,
138-
tol, as.integer(maxIter))
143+
tol, as.integer(maxIter), weightCol)
139144
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
140145
})
141146

@@ -151,6 +156,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
151156
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
152157
#' @param epsilon Positive convergence tolerance of iterations.
153158
#' @param maxit Integer giving the maximal number of IRLS iterations.
159+
#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
160+
#' weights as 1.0.
154161
#' @return \code{glm} returns a fitted generalized linear model.
155162
#' @rdname glm
156163
#' @export
@@ -165,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
165172
#' @note glm since 1.5.0
166173
#' @seealso \link{spark.glm}
167174
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
168-
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) {
169-
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit)
175+
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) {
176+
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol)
170177
})
171178

172179
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ test_that("spark.glm summary", {
118118
expect_equal(stats$df.residual, rStats$df.residual)
119119
expect_equal(stats$aic, rStats$aic)
120120

121+
# Test spark.glm works with weighted dataset
122+
a1 <- c(0, 1, 2, 3)
123+
a2 <- c(5, 2, 1, 3)
124+
w <- c(1, 2, 3, 4)
125+
b <- c(1, 0, 1, 0)
126+
data <- as.data.frame(cbind(a1, a2, w, b))
127+
df <- suppressWarnings(createDataFrame(data))
128+
129+
stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w"))
130+
rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w))
131+
132+
coefs <- unlist(stats$coefficients)
133+
rCoefs <- unlist(rStats$coefficients)
134+
expect_true(all(abs(rCoefs - coefs) < 1e-3))
135+
expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2")))
136+
expect_equal(stats$dispersion, rStats$dispersion)
137+
expect_equal(stats$null.deviance, rStats$null.deviance)
138+
expect_equal(stats$deviance, rStats$deviance)
139+
expect_equal(stats$df.null, rStats$df.null)
140+
expect_equal(stats$df.residual, rStats$df.residual)
141+
expect_equal(stats$aic, rStats$aic)
142+
121143
# Test summary works on base GLM models
122144
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
123145
baseSummary <- summary(baseModel)

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ private[r] object GeneralizedLinearRegressionWrapper
6868
family: String,
6969
link: String,
7070
tol: Double,
71-
maxIter: Int): GeneralizedLinearRegressionWrapper = {
71+
maxIter: Int,
72+
weightCol: String): GeneralizedLinearRegressionWrapper = {
7273
val rFormula = new RFormula()
7374
.setFormula(formula)
7475
val rFormulaModel = rFormula.fit(data)
@@ -84,6 +85,7 @@ private[r] object GeneralizedLinearRegressionWrapper
8485
.setFitIntercept(rFormula.hasIntercept)
8586
.setTol(tol)
8687
.setMaxIter(maxIter)
88+
.setWeightCol(weightCol)
8789
val pipeline = new Pipeline()
8890
.setStages(Array(rFormulaModel, glr))
8991
.fit(data)

0 commit comments

Comments
 (0)