Skip to content

Commit 5680138

Browse files
felixcheungcmonkey
authored and committed
[SPARK-19133][SPARKR][ML] fix glm for Gamma, clarify glm family supported
## What changes were proposed in this pull request?

R's list of families is longer than what Spark supports.

## How was this patch tested?

Manual.

Author: Felix Cheung <[email protected]>

Closes apache#16511 from felixcheung/rdocglmfamily.
1 parent ee54c0b commit 5680138

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

R/pkg/R/mllib_regression.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
5252
#' This can be a character string naming a family function, a family function or
5353
#' the result of a call to a family function. Refer R family at
5454
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
55+
#' Currently these families are supported: \code{binomial}, \code{gaussian},
56+
#' \code{Gamma}, and \code{poisson}.
5557
#' @param tol positive convergence tolerance of iterations.
5658
#' @param maxIter integer giving the maximal number of IRLS iterations.
5759
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -104,8 +106,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
104106
weightCol <- ""
105107
}
106108

109+
# For known families, Gamma is upper-cased
107110
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
108-
"fit", formula, data@sdf, family$family, family$link,
111+
"fit", formula, data@sdf, tolower(family$family), family$link,
109112
tol, as.integer(maxIter), as.character(weightCol), regParam)
110113
new("GeneralizedLinearRegressionModel", jobj = jobj)
111114
})
@@ -120,6 +123,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
120123
#' This can be a character string naming a family function, a family function or
121124
#' the result of a call to a family function. Refer R family at
122125
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
126+
#' Currently these families are supported: \code{binomial}, \code{gaussian},
127+
#' \code{Gamma}, and \code{poisson}.
123128
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
124129
#' weights as 1.0.
125130
#' @param epsilon positive convergence tolerance of iterations.

R/pkg/R/sparkR.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ sparkR.session <- function(
423423
#' sparkR.session()
424424
#' url <- sparkR.uiWebUrl()
425425
#' }
426-
#' @note sparkR.uiWebUrl since 2.2.0
426+
#' @note sparkR.uiWebUrl since 2.1.1
427427
sparkR.uiWebUrl <- function() {
428428
sc <- sparkR.callJMethod(getSparkContext(), "sc")
429429
u <- callJMethod(sc, "uiWebUrl")

R/pkg/inst/tests/testthat/test_mllib_regression.R

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,22 @@ test_that("spark.glm and predict", {
6161

6262
# poisson family
6363
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
64-
family = poisson(link = identity))
64+
family = poisson(link = identity))
6565
prediction <- predict(model, training)
6666
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
6767
vals <- collect(select(prediction, "prediction"))
6868
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
69-
data = iris, family = poisson(link = identity)), iris))
69+
data = iris, family = poisson(link = identity)), iris))
7070
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
7171

72+
# Gamma family
73+
x <- runif(100, -1, 1)
74+
y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
75+
df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
76+
model <- glm(y ~ x, family = Gamma, df)
77+
out <- capture.output(print(summary(model)))
78+
expect_true(any(grepl("Dispersion parameter for gamma family", out)))
79+
7280
# Test stats::predict is working
7381
x <- rnorm(15)
7482
y <- x + rnorm(15)
@@ -103,11 +111,11 @@ test_that("spark.glm summary", {
103111
df <- suppressWarnings(createDataFrame(iris))
104112
training <- df[df$Species %in% c("versicolor", "virginica"), ]
105113
stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
106-
family = binomial(link = "logit")))
114+
family = binomial(link = "logit")))
107115

108116
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
109117
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
110-
family = binomial(link = "logit")))
118+
family = binomial(link = "logit")))
111119

112120
coefs <- unlist(stats$coefficients)
113121
rCoefs <- unlist(rStats$coefficients)
@@ -222,7 +230,7 @@ test_that("glm and predict", {
222230
training <- suppressWarnings(createDataFrame(iris))
223231
# gaussian family
224232
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
225-
prediction <- predict(model, training)
233+
prediction <- predict(model, training)
226234
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
227235
vals <- collect(select(prediction, "prediction"))
228236
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
@@ -235,7 +243,7 @@ test_that("glm and predict", {
235243
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
236244
vals <- collect(select(prediction, "prediction"))
237245
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
238-
data = iris, family = poisson(link = identity)), iris))
246+
data = iris, family = poisson(link = identity)), iris))
239247
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
240248

241249
# Test stats::predict is working
@@ -268,11 +276,11 @@ test_that("glm summary", {
268276
df <- suppressWarnings(createDataFrame(iris))
269277
training <- df[df$Species %in% c("versicolor", "virginica"), ]
270278
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
271-
family = binomial(link = "logit")))
279+
family = binomial(link = "logit")))
272280

273281
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
274282
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
275-
family = binomial(link = "logit")))
283+
family = binomial(link = "logit")))
276284

277285
coefs <- unlist(stats$coefficients)
278286
rCoefs <- unlist(rStats$coefficients)
@@ -409,7 +417,7 @@ test_that("spark.survreg", {
409417
x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
410418
expect_error(
411419
model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
412-
NA)
420+
NA)
413421
expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
414422
}
415423
})

0 commit comments

Comments
 (0)