Skip to content

Commit 39a4c4c

Browse files
committed
make SparkR model params and default values consistent with MLlib
1 parent 58f6e27 commit 39a4c4c

File tree

4 files changed

+43
-45
lines changed

4 files changed

+43
-45
lines changed

R/pkg/R/mllib.R

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ setClass("KMeansModel", representation(jobj = "jobj"))
6464
#' This can be a character string naming a family function, a family function or
6565
#' the result of a call to a family function. Refer to the R documentation of family functions at
6666
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
67-
#' @param epsilon Positive convergence tolerance of iterations.
68-
#' @param maxit Integer giving the maximal number of IRLS iterations.
67+
#' @param tol Positive convergence tolerance of iterations.
68+
#' @param maxIter Integer giving the maximal number of IRLS iterations.
6969
#' @return a fitted generalized linear model
7070
#' @rdname spark.glm
7171
#' @export
@@ -74,32 +74,30 @@ setClass("KMeansModel", representation(jobj = "jobj"))
7474
#' sparkR.session()
7575
#' data(iris)
7676
#' df <- createDataFrame(iris)
77-
#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family="gaussian")
77+
#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")
7878
#' summary(model)
7979
#' }
8080
#' @note spark.glm since 2.0.0
81-
setMethod(
82-
"spark.glm",
83-
signature(data = "SparkDataFrame", formula = "formula"),
84-
function(data, formula, family = gaussian, epsilon = 1e-06, maxit = 25) {
85-
if (is.character(family)) {
86-
family <- get(family, mode = "function", envir = parent.frame())
87-
}
88-
if (is.function(family)) {
89-
family <- family()
90-
}
91-
if (is.null(family$family)) {
92-
print(family)
93-
stop("'family' not recognized")
94-
}
81+
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
82+
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) {
83+
if (is.character(family)) {
84+
family <- get(family, mode = "function", envir = parent.frame())
85+
}
86+
if (is.function(family)) {
87+
family <- family()
88+
}
89+
if (is.null(family$family)) {
90+
print(family)
91+
stop("'family' not recognized")
92+
}
9593

96-
formula <- paste(deparse(formula), collapse = "")
94+
formula <- paste(deparse(formula), collapse = "")
9795

98-
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
99-
"fit", formula, data@sdf, family$family, family$link,
100-
epsilon, as.integer(maxit))
101-
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
102-
})
96+
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
97+
"fit", formula, data@sdf, family$family, family$link,
98+
tol, as.integer(maxIter))
99+
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
100+
})
103101

104102
#' Fits a generalized linear model (R-compliant).
105103
#'
@@ -122,13 +120,13 @@ setMethod(
122120
#' sparkR.session()
123121
#' data(iris)
124122
#' df <- createDataFrame(iris)
125-
#' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian")
123+
#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian")
126124
#' summary(model)
127125
#' }
128126
#' @note glm since 1.5.0
129127
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
130-
function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
131-
spark.glm(data, formula, family, epsilon, maxit)
128+
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) {
129+
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit)
132130
})
133131

134132
#' Get the summary of a generalized linear model
@@ -298,17 +296,17 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
298296
#' @export
299297
#' @examples
300298
#' \dontrun{
301-
#' model <- spark.kmeans(data, ~ ., k=2, initMode="random")
299+
#' model <- spark.kmeans(data, ~ ., k = 4, initMode = "random")
302300
#' }
303301
#' @note spark.kmeans since 2.0.0
304302
setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
305-
function(data, formula, k, maxIter = 10, initMode = c("random", "k-means||")) {
303+
function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) {
306304
formula <- paste(deparse(formula), collapse = "")
307305
initMode <- match.arg(initMode)
308306
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
309307
as.integer(k), as.integer(maxIter), initMode)
310308
return(new("KMeansModel", jobj = jobj))
311-
})
309+
})
312310

313311
#' Get fitted result from a k-means model
314312
#'
@@ -401,24 +399,24 @@ setMethod("predict", signature(object = "KMeansModel"),
401399
#' @param data SparkDataFrame for training
402400
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
403401
#' operators are supported, including '~', '.', ':', '+', and '-'.
404-
#' @param laplace Smoothing parameter
402+
#' @param smoothing Smoothing parameter
405403
#' @return a fitted naive Bayes model
406404
#' @rdname spark.naiveBayes
407405
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
408406
#' @export
409407
#' @examples
410408
#' \dontrun{
411409
#' df <- createDataFrame(infert)
412-
#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
410+
#' model <- spark.naiveBayes(df, education ~ ., smoothing = 0)
413411
#'}
414412
#' @note spark.naiveBayes since 2.0.0
415413
setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
416-
function(data, formula, laplace = 0, ...) {
417-
formula <- paste(deparse(formula), collapse = "")
418-
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
419-
formula, data@sdf, laplace)
420-
return(new("NaiveBayesModel", jobj = jobj))
421-
})
414+
function(data, formula, smoothing = 1.0, ...) {
415+
formula <- paste(deparse(formula), collapse = "")
416+
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
417+
formula, data@sdf, smoothing)
418+
return(new("NaiveBayesModel", jobj = jobj))
419+
})
422420

423421
#' Save fitted MLlib model to the input path
424422
#'
@@ -435,7 +433,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
435433
#' @examples
436434
#' \dontrun{
437435
#' df <- createDataFrame(infert)
438-
#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
436+
#' model <- spark.naiveBayes(df, education ~ ., smoothing = 0)
439437
#' path <- "path/to/model"
440438
#' write.ml(model, path)
441439
#' }

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ test_that("spark.naiveBayes", {
363363
t <- as.data.frame(Titanic)
364364
t1 <- t[t$Freq > 0, -5]
365365
df <- suppressWarnings(createDataFrame(t1))
366-
m <- spark.naiveBayes(df, Survived ~ .)
366+
m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0)
367367
s <- summary(m)
368368
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
369369
expect_equal(sum(s$apriori), 1)

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ private[r] object GeneralizedLinearRegressionWrapper
6767
data: DataFrame,
6868
family: String,
6969
link: String,
70-
epsilon: Double,
71-
maxit: Int): GeneralizedLinearRegressionWrapper = {
70+
tol: Double,
71+
maxIter: Int): GeneralizedLinearRegressionWrapper = {
7272
val rFormula = new RFormula()
7373
.setFormula(formula)
7474
val rFormulaModel = rFormula.fit(data)
@@ -82,8 +82,8 @@ private[r] object GeneralizedLinearRegressionWrapper
8282
.setFamily(family)
8383
.setLink(link)
8484
.setFitIntercept(rFormula.hasIntercept)
85-
.setTol(epsilon)
86-
.setMaxIter(maxit)
85+
.setTol(tol)
86+
.setMaxIter(maxIter)
8787
val pipeline = new Pipeline()
8888
.setStages(Array(rFormulaModel, glr))
8989
.fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
5656
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
5757
val PREDICTED_LABEL_COL = "prediction"
5858

59-
def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = {
59+
def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
6060
val rFormula = new RFormula()
6161
.setFormula(formula)
6262
.fit(data)
@@ -70,7 +70,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
7070
val features = featureAttrs.map(_.name.get)
7171
// assemble and fit the pipeline
7272
val naiveBayes = new NaiveBayes()
73-
.setSmoothing(laplace)
73+
.setSmoothing(smoothing)
7474
.setModelType("bernoulli")
7575
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
7676
val idxToStr = new IndexToString()

0 commit comments

Comments (0)