
Commit bc36fe6

Authored by thunterdb, committed by mengxr
1 parent 43b149f

[SPARK-14831][SPARKR] Make the SparkR MLlib API more consistent with Spark

## What changes were proposed in this pull request?

This PR splits the MLlib algorithms into two flavors:

- the R flavor, which tries to mimic the existing R API for these algorithms (and works as an S4 specialization for Spark dataframes)
- the Spark flavor, which follows the same API and naming conventions as the rest of the MLlib algorithms in the other languages

In practice, the former calls the latter.

## How was this patch tested?

The tests for the various algorithms were adapted to be run against both interfaces.

Author: Timothy Hunter <[email protected]>

Closes #12789 from thunterdb/14831.
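To make the split concrete, here is a minimal usage sketch contrasting the two flavors (it assumes an initialized sqlContext, as in the examples in this diff):

data(iris)
df <- createDataFrame(sqlContext, iris)

# R flavor: mirrors the stats::glm(formula, family, data) calling convention
model_r <- glm(Sepal_Length ~ Sepal_Width, family = "gaussian", data = df)

# Spark flavor: data-first argument order, consistent with MLlib's Scala/Python APIs
model_spark <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")

# The R flavor delegates to the Spark flavor, so both yield the same model class
summary(model_spark)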

File tree: 4 files changed, +247 -72 lines changed

R/pkg/NAMESPACE

Lines changed: 4 additions & 3 deletions

@@ -12,12 +12,13 @@ export("print.jobj")

 # MLlib integration
 exportMethods("glm",
+              "spark.glm",
               "predict",
               "summary",
-              "kmeans",
+              "spark.kmeans",
               "fitted",
-              "naiveBayes",
-              "survreg")
+              "spark.naiveBayes",
+              "spark.survreg")

 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 10 additions & 6 deletions

@@ -1181,6 +1181,10 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
 #' @export
 setGeneric("year", function(x) { standardGeneric("year") })

+#' @rdname spark.glm
+#' @export
+setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+
 #' @rdname glm
 #' @export
 setGeneric("glm")

@@ -1193,21 +1197,21 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @export
 setGeneric("rbind", signature = "...")

-#' @rdname kmeans
+#' @rdname spark.kmeans
 #' @export
-setGeneric("kmeans")
+setGeneric("spark.kmeans", function(data, k, ...) { standardGeneric("spark.kmeans") })

 #' @rdname fitted
 #' @export
 setGeneric("fitted")

-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
 #' @export
-setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })

-#' @rdname survreg
+#' @rdname spark.survreg
 #' @export
-setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })

 #' @rdname ml.save
 #' @export
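Aside (not part of the diff): these declarations are standard S4 machinery. setGeneric fixes the argument order of the generic (data first for the spark.* flavor), and standardGeneric dispatches to whichever setMethod matches the classes in the call. A minimal self-contained sketch, with a hypothetical DemoFrame class standing in for SparkDataFrame:

library(methods)

# Declare a data-first generic, mirroring spark.kmeans above
setGeneric("spark.demo", function(data, k, ...) { standardGeneric("spark.demo") })

# A hypothetical stand-in for SparkDataFrame
setClass("DemoFrame", representation(n = "numeric"))

# Dispatch selects this body whenever the first argument is a DemoFrame
setMethod("spark.demo", signature(data = "DemoFrame"),
          function(data, k, ...) {
            paste("fitting k =", k, "on", data@n, "rows")
          })

spark.demo(new("DemoFrame", n = 150), k = 3)  # "fitting k = 3 on 150 rows"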

R/pkg/R/mllib.R

Lines changed: 97 additions & 58 deletions

@@ -17,6 +17,14 @@

 # mllib.R: Provides methods for MLlib integration

+# Integration with R's standard functions.
+# Most of MLlib's algorithms are provided in two flavours:
+# - a specialization of the default R methods (glm). These methods try to respect
+#   the inputs and the outputs of R's method to the largest extent, but some small differences
+#   may exist.
+# - a set of methods that reflect the arguments of the other languages supported by Spark. These
+#   methods are prefixed with `spark.`: spark.glm, spark.kmeans, etc.
+
 #' @title S4 class that represents a generalized linear model
 #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
 #' @export
@@ -39,6 +47,54 @@ setClass("KMeansModel", representation(jobj = "jobj"))

 #' Fits a generalized linear model
 #'
+#' Fits a generalized linear model against a Spark DataFrame.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param family A description of the error distribution and link function to be used in the model.
+#'               This can be a character string naming a family function, a family function, or
+#'               the result of a call to a family function. Refer to the R family documentation at
+#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' @param epsilon Positive convergence tolerance of iterations.
+#' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @return a fitted generalized linear model
+#' @rdname spark.glm
+#' @export
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")
+#' summary(model)
+#' }
+setMethod(
+  "spark.glm",
+  signature(data = "SparkDataFrame", formula = "formula"),
+  function(data, formula, family = gaussian, epsilon = 1e-06, maxit = 25) {
+    if (is.character(family)) {
+      family <- get(family, mode = "function", envir = parent.frame())
+    }
+    if (is.function(family)) {
+      family <- family()
+    }
+    if (is.null(family$family)) {
+      print(family)
+      stop("'family' not recognized")
+    }
+
+    formula <- paste(deparse(formula), collapse = "")
+
+    jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
+                        "fit", formula, data@sdf, family$family, family$link,
+                        epsilon, as.integer(maxit))
+    return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+  })
+
+#' Fits a generalized linear model (R-compliant).
+#'
 #' Fits a generalized linear model, similarly to R's glm().
 #'
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula

@@ -64,23 +120,7 @@ setClass("KMeansModel", representation(jobj = "jobj"))

 #' }
 setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
           function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
-            if (is.character(family)) {
-              family <- get(family, mode = "function", envir = parent.frame())
-            }
-            if (is.function(family)) {
-              family <- family()
-            }
-            if (is.null(family$family)) {
-              print(family)
-              stop("'family' not recognized")
-            }
-
-            formula <- paste(deparse(formula), collapse = "")
-
-            jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
-                                "fit", formula, data@sdf, family$family, family$link,
-                                epsilon, as.integer(maxit))
-            return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+            spark.glm(data, formula, family, epsilon, maxit)
           })

 #' Get the summary of a generalized linear model
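The family-resolution logic retained inside spark.glm mirrors stats::glm, so the three spellings below should be interchangeable (a sketch, reusing the df from the roxygen example above):

# 1. character string: resolved via get(family, mode = "function", ...)
m1 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")

# 2. family function: called with no arguments to obtain the family object
m2 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = gaussian)

# 3. family object: its $family and $link fields are passed to the JVM wrapper
m3 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = gaussian())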
@@ -188,7 +228,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #'}

@@ -208,7 +248,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
 #' summary(model)
 #'}
 setMethod("summary", signature(object = "NaiveBayesModel"),

@@ -230,23 +270,23 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #'
 #' Fit a k-means model, similarly to R's kmeans().
 #'
-#' @param x SparkDataFrame for training
-#' @param centers Number of centers
-#' @param iter.max Maximum iteration number
-#' @param algorithm Algorithm chosen to fit the model
+#' @param data SparkDataFrame for training
+#' @param k Number of centers
+#' @param maxIter Maximum iteration number
+#' @param initializationMode Algorithm chosen to fit the model
 #' @return A fitted k-means model
-#' @rdname kmeans
+#' @rdname spark.kmeans
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm = "random")
+#' model <- spark.kmeans(data, k = 2, initializationMode = "random")
 #' }
-setMethod("kmeans", signature(x = "SparkDataFrame"),
-          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
-            columnNames <- as.array(colnames(x))
-            algorithm <- match.arg(algorithm)
-            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
-                                centers, iter.max, algorithm, columnNames)
+setMethod("spark.kmeans", signature(data = "SparkDataFrame"),
+          function(data, k, maxIter = 10, initializationMode = c("random", "k-means||")) {
+            columnNames <- as.array(colnames(data))
+            initializationMode <- match.arg(initializationMode)
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf,
+                                k, maxIter, initializationMode, columnNames)
             return(new("KMeansModel", jobj = jobj))
           })
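An end-to-end sketch of the renamed k-means interface (not part of the diff; assumes a SparkDataFrame training with numeric feature columns):

# initializationMode is validated by match.arg() against c("random", "k-means||")
model <- spark.kmeans(training, k = 2, maxIter = 10, initializationMode = "random")
summary(model)           # cluster centers and sizes
showDF(fitted(model))    # cluster assignment for each training row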

@@ -261,7 +301,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}

@@ -288,7 +328,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' summary(model)
 #' }
 setMethod("summary", signature(object = "KMeansModel"),

@@ -322,7 +362,7 @@ setMethod("summary", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #' }

@@ -333,30 +373,28 @@ setMethod("predict", signature(object = "KMeansModel"),

 #' Fit a Bernoulli naive Bayes model
 #'
-#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
-#' categorical features are supported. The input should be a SparkDataFrame of observations instead
-#' of a contingency table.
+#' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported).
 #'
+#' @param data SparkDataFrame for training
 #' @param object A symbolic description of the model to be fitted. Currently only a few formula
 #'               operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data SparkDataFrame for training
 #' @param laplace Smoothing parameter
 #' @return a fitted naive Bayes model
-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
 #' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
 #' @export
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
 #'}
-setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
-          function(formula, data, laplace = 0, ...) {
-            formula <- paste(deparse(formula), collapse = "")
-            jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
-                                formula, data@sdf, laplace)
-            return(new("NaiveBayesModel", jobj = jobj))
-          })
+setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, laplace = 0, ...) {
+            formula <- paste(deparse(formula), collapse = "")
+            jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
+                                formula, data@sdf, laplace)
+            return(new("NaiveBayesModel", jobj = jobj))
+          })

 #' Save the Bernoulli naive Bayes model to the input path.
 #'
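Each wrapper serializes the formula to a single string before crossing into the JVM; the paste(deparse(...), collapse = "") idiom guards against deparse() splitting a long formula across several lines. A standalone base-R illustration:

f <- education ~ age + parity + induced + spontaneous
chunks <- deparse(f)                      # may be a multi-element vector for long formulas
formula_string <- paste(chunks, collapse = "")
formula_string                            # "education ~ age + parity + induced + spontaneous"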
@@ -371,7 +409,7 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }

@@ -396,7 +434,7 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }

@@ -446,7 +484,7 @@ setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm = "random")
+#' model <- spark.kmeans(x, k = 2, initializationMode = "random")
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }
@@ -489,29 +527,30 @@ ml.load <- function(path) {

 #' Fit an accelerated failure time (AFT) survival regression model.
 #'
-#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#' Fit an accelerated failure time (AFT) survival regression model on a Spark DataFrame.
 #'
+#' @param data SparkDataFrame for training.
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', ':', '+', and '-'.
 #'                Note that operator '.' is not supported currently.
-#' @param data SparkDataFrame for training.
 #' @return a fitted AFT survival regression model
-#' @rdname survreg
+#' @rdname spark.survreg
 #' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
 #' @export
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, ovarian)
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
 #' }
-setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
-          function(formula, data, ...) {
+setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, ...) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                 "fit", formula, data@sdf)
             return(new("AFTSurvivalRegressionModel", jobj = jobj))
           })

+
 #' Get the summary of an AFT survival regression model

@@ -523,7 +562,7 @@ setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' summary(model)
 #' }
 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),

@@ -548,7 +587,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #' }
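The fourth changed file (the adapted test suite) is not shown in this excerpt. Per the commit message, each algorithm's tests now run against both interfaces; a hedged sketch of that pattern, assuming SparkR's testthat setup and an initialized sqlContext:

library(testthat)

test_that("glm and spark.glm produce equivalent models", {
  training <- createDataFrame(sqlContext, iris)
  m1 <- glm(Sepal_Length ~ Sepal_Width, family = "gaussian", data = training)
  m2 <- spark.glm(training, Sepal_Length ~ Sepal_Width, family = "gaussian")
  # Both flavors should agree, since glm() simply delegates to spark.glm()
  expect_equal(summary(m1)$coefficients, summary(m2)$coefficients)
})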
