1717
1818# mllib.R: Provides methods for MLlib integration
1919
20+ # Integration with R's standard functions.
21+ # Most of MLlib's argorithms are provided in two flavours:
22+ # - a specialization of the default R methods (glm). These methods try to respect
23+ # the inputs and the outputs of R's method to the largest extent, but some small differences
24+ # may exist.
25+ # - a set of methods that reflect the arguments of the other languages supported by Spark. These
26+ # methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc.
27+
2028# ' @title S4 class that represents a generalized linear model
2129# ' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
2230# ' @export
@@ -39,6 +47,54 @@ setClass("KMeansModel", representation(jobj = "jobj"))
3947
4048# ' Fits a generalized linear model
4149# '
50+ # ' Fits a generalized linear model against a Spark DataFrame.
51+ # '
52+ # ' @param data SparkDataFrame for training.
53+ # ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
54+ # ' operators are supported, including '~', '.', ':', '+', and '-'.
55+ # ' @param family A description of the error distribution and link function to be used in the model.
56+ # ' This can be a character string naming a family function, a family function or
57+ # ' the result of a call to a family function. Refer R family at
58+ # ' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
59+ # ' @param epsilon Positive convergence tolerance of iterations.
60+ # ' @param maxit Integer giving the maximal number of IRLS iterations.
61+ # ' @return a fitted generalized linear model
62+ # ' @rdname spark.glm
63+ # ' @export
64+ # ' @examples
65+ # ' \dontrun{
66+ # ' sc <- sparkR.init()
67+ # ' sqlContext <- sparkRSQL.init(sc)
68+ # ' data(iris)
69+ # ' df <- createDataFrame(sqlContext, iris)
70+ # ' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family="gaussian")
71+ # ' summary(model)
72+ # ' }
73+ setMethod(
74+ " spark.glm" ,
75+ signature(data = " SparkDataFrame" , formula = " formula" ),
76+ function (data , formula , family = gaussian , epsilon = 1e-06 , maxit = 25 ) {
77+ if (is.character(family )) {
78+ family <- get(family , mode = " function" , envir = parent.frame())
79+ }
80+ if (is.function(family )) {
81+ family <- family()
82+ }
83+ if (is.null(family $ family )) {
84+ print(family )
85+ stop(" 'family' not recognized" )
86+ }
87+
88+ formula <- paste(deparse(formula ), collapse = " " )
89+
90+ jobj <- callJStatic(" org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" ,
91+ " fit" , formula , data @ sdf , family $ family , family $ link ,
92+ epsilon , as.integer(maxit ))
93+ return (new(" GeneralizedLinearRegressionModel" , jobj = jobj ))
94+ })
95+
96+ # ' Fits a generalized linear model (R-compliant).
97+ # '
4298# ' Fits a generalized linear model, similarly to R's glm().
4399# '
44100# ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
@@ -64,23 +120,7 @@ setClass("KMeansModel", representation(jobj = "jobj"))
64120# ' }
65121setMethod ("glm ", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
66122 function (formula , family = gaussian , data , epsilon = 1e-06 , maxit = 25 ) {
67- if (is.character(family )) {
68- family <- get(family , mode = " function" , envir = parent.frame())
69- }
70- if (is.function(family )) {
71- family <- family()
72- }
73- if (is.null(family $ family )) {
74- print(family )
75- stop(" 'family' not recognized" )
76- }
77-
78- formula <- paste(deparse(formula ), collapse = " " )
79-
80- jobj <- callJStatic(" org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" ,
81- " fit" , formula , data @ sdf , family $ family , family $ link ,
82- epsilon , as.integer(maxit ))
83- return (new(" GeneralizedLinearRegressionModel" , jobj = jobj ))
123+ spark.glm(data , formula , family , epsilon , maxit )
84124 })
85125
86126# ' Get the summary of a generalized linear model
@@ -188,7 +228,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
188228# ' @export
189229# ' @examples
190230# ' \dontrun{
191- # ' model <- naiveBayes(y ~ x, trainingData )
231+ # ' model <- spark. naiveBayes(trainingData, y ~ x)
192232# ' predicted <- predict(model, testData)
193233# ' showDF(predicted)
194234# '}
@@ -208,7 +248,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
208248# ' @export
209249# ' @examples
210250# ' \dontrun{
211- # ' model <- naiveBayes(y ~ x, trainingData )
251+ # ' model <- spark. naiveBayes(trainingData, y ~ x)
212252# ' summary(model)
213253# '}
214254setMethod ("summary ", signature(object = "NaiveBayesModel"),
@@ -230,23 +270,23 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
230270# '
231271# ' Fit a k-means model, similarly to R's kmeans().
232272# '
233- # ' @param x SparkDataFrame for training
234- # ' @param centers Number of centers
235- # ' @param iter.max Maximum iteration number
236- # ' @param algorithm Algorithm choosen to fit the model
273+ # ' @param data SparkDataFrame for training
274+ # ' @param k Number of centers
275+ # ' @param maxIter Maximum iteration number
276+ # ' @param initializationMode Algorithm choosen to fit the model
237277# ' @return A fitted k-means model
238- # ' @rdname kmeans
278+ # ' @rdname spark. kmeans
239279# ' @export
240280# ' @examples
241281# ' \dontrun{
242- # ' model <- kmeans(x, centers = 2, algorithm ="random")
282+ # ' model <- spark. kmeans(data, k = 2, initializationMode ="random")
243283# ' }
244- setMethod ("kmeans ", signature(x = "SparkDataFrame"),
245- function (x , centers , iter.max = 10 , algorithm = c(" random" , " k-means||" )) {
246- columnNames <- as.array(colnames(x ))
247- algorithm <- match.arg(algorithm )
248- jobj <- callJStatic(" org.apache.spark.ml.r.KMeansWrapper" , " fit" , x @ sdf ,
249- centers , iter.max , algorithm , columnNames )
284+ setMethod ("spark. kmeans ", signature(data = "SparkDataFrame"),
285+ function (data , k , maxIter = 10 , initializationMode = c(" random" , " k-means||" )) {
286+ columnNames <- as.array(colnames(data ))
287+ initializationMode <- match.arg(initializationMode )
288+ jobj <- callJStatic(" org.apache.spark.ml.r.KMeansWrapper" , " fit" , data @ sdf ,
289+ k , maxIter , initializationMode , columnNames )
250290 return (new(" KMeansModel" , jobj = jobj ))
251291 })
252292
@@ -261,7 +301,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
261301# ' @export
262302# ' @examples
263303# ' \dontrun{
264- # ' model <- kmeans(trainingData, 2)
304+ # ' model <- spark. kmeans(trainingData, 2)
265305# ' fitted.model <- fitted(model)
266306# ' showDF(fitted.model)
267307# '}
@@ -288,7 +328,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
288328# ' @export
289329# ' @examples
290330# ' \dontrun{
291- # ' model <- kmeans(trainingData, 2)
331+ # ' model <- spark. kmeans(trainingData, 2)
292332# ' summary(model)
293333# ' }
294334setMethod ("summary ", signature(object = "KMeansModel"),
@@ -322,7 +362,7 @@ setMethod("summary", signature(object = "KMeansModel"),
322362# ' @export
323363# ' @examples
324364# ' \dontrun{
325- # ' model <- kmeans(trainingData, 2)
365+ # ' model <- spark. kmeans(trainingData, 2)
326366# ' predicted <- predict(model, testData)
327367# ' showDF(predicted)
328368# ' }
@@ -333,30 +373,28 @@ setMethod("predict", signature(object = "KMeansModel"),
333373
334374# ' Fit a Bernoulli naive Bayes model
335375# '
336- # ' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
337- # ' categorical features are supported. The input should be a SparkDataFrame of observations instead
338- # ' of a contingency table.
376+ # ' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported).
339377# '
378+ # ' @param data SparkDataFrame for training
340379# ' @param object A symbolic description of the model to be fitted. Currently only a few formula
341380# ' operators are supported, including '~', '.', ':', '+', and '-'.
342- # ' @param data SparkDataFrame for training
343381# ' @param laplace Smoothing parameter
344382# ' @return a fitted naive Bayes model
345- # ' @rdname naiveBayes
383+ # ' @rdname spark. naiveBayes
346384# ' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
347385# ' @export
348386# ' @examples
349387# ' \dontrun{
350388# ' df <- createDataFrame(sqlContext, infert)
351- # ' model <- naiveBayes(education ~ ., df , laplace = 0)
389+ # ' model <- spark. naiveBayes(df, education ~ ., laplace = 0)
352390# '}
353- setMethod ("naiveBayes ", signature(formula = "formula ", data = "SparkDataFrame "),
354- function (formula , data , laplace = 0 , ... ) {
355- formula <- paste(deparse(formula ), collapse = " " )
356- jobj <- callJStatic(" org.apache.spark.ml.r.NaiveBayesWrapper" , " fit" ,
357- formula , data @ sdf , laplace )
358- return (new(" NaiveBayesModel" , jobj = jobj ))
359- })
391+ setMethod ("spark. naiveBayes ", signature(data = "SparkDataFrame ", formula = "formula "),
392+ function (data , formula , laplace = 0 , ... ) {
393+ formula <- paste(deparse(formula ), collapse = " " )
394+ jobj <- callJStatic(" org.apache.spark.ml.r.NaiveBayesWrapper" , " fit" ,
395+ formula , data @ sdf , laplace )
396+ return (new(" NaiveBayesModel" , jobj = jobj ))
397+ })
360398
361399# ' Save the Bernoulli naive Bayes model to the input path.
362400# '
@@ -371,7 +409,7 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
371409# ' @examples
372410# ' \dontrun{
373411# ' df <- createDataFrame(sqlContext, infert)
374- # ' model <- naiveBayes(education ~ ., df, laplace = 0)
412+ # ' model <- spark. naiveBayes(education ~ ., df, laplace = 0)
375413# ' path <- "path/to/model"
376414# ' ml.save(model, path)
377415# ' }
@@ -396,7 +434,7 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
396434# ' @export
397435# ' @examples
398436# ' \dontrun{
399- # ' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData )
437+ # ' model <- spark. survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
400438# ' path <- "path/to/model"
401439# ' ml.save(model, path)
402440# ' }
@@ -446,7 +484,7 @@ setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path
446484# ' @export
447485# ' @examples
448486# ' \dontrun{
449- # ' model <- kmeans(x, centers = 2, algorithm ="random")
487+ # ' model <- spark. kmeans(x, k = 2, initializationMode ="random")
450488# ' path <- "path/to/model"
451489# ' ml.save(model, path)
452490# ' }
@@ -489,29 +527,30 @@ ml.load <- function(path) {
489527
490528# ' Fit an accelerated failure time (AFT) survival regression model.
491529# '
492- # ' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg() .
530+ # ' Fit an accelerated failure time (AFT) survival regression model on a Spark DataFrame .
493531# '
532+ # ' @param data SparkDataFrame for training.
494533# ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
495534# ' operators are supported, including '~', ':', '+', and '-'.
496535# ' Note that operator '.' is not supported currently.
497- # ' @param data SparkDataFrame for training.
498536# ' @return a fitted AFT survival regression model
499- # ' @rdname survreg
537+ # ' @rdname spark. survreg
500538# ' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
501539# ' @export
502540# ' @examples
503541# ' \dontrun{
504542# ' df <- createDataFrame(sqlContext, ovarian)
505- # ' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df )
543+ # ' model <- spark. survreg(Surv(df, futime, fustat) ~ ecog_ps + rx)
506544# ' }
507- setMethod ("survreg ", signature(formula = "formula ", data = "SparkDataFrame "),
508- function (formula , data , ... ) {
545+ setMethod ("spark. survreg ", signature(data = "SparkDataFrame ", formula = "formula "),
546+ function (data , formula , ... ) {
509547 formula <- paste(deparse(formula ), collapse = " " )
510548 jobj <- callJStatic(" org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" ,
511549 " fit" , formula , data @ sdf )
512550 return (new(" AFTSurvivalRegressionModel" , jobj = jobj ))
513551 })
514552
553+
515554# ' Get the summary of an AFT survival regression model
516555# '
517556# ' Returns the summary of an AFT survival regression model produced by survreg(),
@@ -523,7 +562,7 @@ setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
523562# ' @export
524563# ' @examples
525564# ' \dontrun{
526- # ' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData )
565+ # ' model <- spark. survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
527566# ' summary(model)
528567# ' }
529568setMethod ("summary ", signature(object = "AFTSurvivalRegressionModel"),
@@ -548,7 +587,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
548587# ' @export
549588# ' @examples
550589# ' \dontrun{
551- # ' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData )
590+ # ' model <- spark. survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
552591# ' predicted <- predict(model, testData)
553592# ' showDF(predicted)
554593# ' }
0 commit comments