Skip to content

Commit a355dde

Browse files
committed
Review SparkR ML wrappers API for 2.1
1 parent f1fca81 commit a355dde

File tree

4 files changed

+12
-19
lines changed

4 files changed

+12
-19
lines changed

R/pkg/R/mllib.R

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,6 @@ setMethod("predict", signature(object = "KMeansModel"),
733733
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
734734
#' is the original probability of that class and t is the class's threshold.
735735
#' @param weightCol The weight column name.
736-
#' @param probabilityCol column name for predicted class conditional probabilities.
737736
#' @param ... additional arguments passed to the method.
738737
#' @return \code{spark.logit} returns a fitted logistic regression model
739738
#' @rdname spark.logit
@@ -772,7 +771,7 @@ setMethod("predict", signature(object = "KMeansModel"),
772771
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
773772
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
774773
tol = 1E-6, family = "auto", standardization = TRUE,
775-
thresholds = 0.5, weightCol = NULL, probabilityCol = "probability") {
774+
thresholds = 0.5, weightCol = NULL) {
776775
formula <- paste(deparse(formula), collapse = "")
777776

778777
if (is.null(weightCol)) {
@@ -784,7 +783,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
784783
as.numeric(elasticNetParam), as.integer(maxIter),
785784
as.numeric(tol), as.character(family),
786785
as.logical(standardization), as.array(thresholds),
787-
as.character(weightCol), as.character(probabilityCol))
786+
as.character(weightCol))
788787
new("LogisticRegressionModel", jobj = jobj)
789788
})
790789

@@ -1425,7 +1424,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
14251424
#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
14261425
#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
14271426
#' @param rank rank of the matrix factorization (> 0).
1428-
#' @param reg regularization parameter (>= 0).
1427+
#' @param regParam regularization parameter (>= 0).
14291428
#' @param maxIter maximum number of iterations (>= 0).
14301429
#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
14311430
#' @param implicitPrefs logical value indicating whether to use implicit preference.
@@ -1464,20 +1463,20 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
14641463
#'
14651464
#' # set other arguments
14661465
#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
1467-
#' reg = 0.1, nonnegative = TRUE)
1466+
#' regParam = 0.1, nonnegative = TRUE)
14681467
#' statsS <- summary(modelS)
14691468
#' }
14701469
#' @note spark.als since 2.1.0
14711470
setMethod("spark.als", signature(data = "SparkDataFrame"),
14721471
function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
1473-
rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE,
1472+
rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
14741473
implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
14751474
checkpointInterval = 10, seed = 0) {
14761475

14771476
if (!is.numeric(rank) || rank <= 0) {
14781477
stop("rank should be a positive number.")
14791478
}
1480-
if (!is.numeric(reg) || reg < 0) {
1479+
if (!is.numeric(regParam) || regParam < 0) {
14811480
stop("regParam should be a nonnegative number.")
14821481
}
14831482
if (!is.numeric(maxIter) || maxIter <= 0) {
@@ -1486,7 +1485,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
14861485

14871486
jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
14881487
"fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
1489-
reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
1488+
regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
14901489
as.integer(numUserBlocks), as.integer(numItemBlocks),
14911490
as.integer(checkpointInterval), as.integer(seed))
14921491
new("ALSModel", jobj = jobj)
@@ -1684,8 +1683,6 @@ print.summary.KSTest <- function(x, ...) {
16841683
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
16851684
#' can speed up training of deeper trees. Users can set how often should the
16861685
#' cache be checkpointed or disable it by setting checkpointInterval.
1687-
#' @param probabilityCol column name for predicted class conditional probabilities, only for
1688-
#' classification.
16891686
#' @param ... additional arguments passed to the method.
16901687
#' @aliases spark.randomForest,SparkDataFrame,formula-method
16911688
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -1720,7 +1717,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
17201717
maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
17211718
featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
17221719
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
1723-
maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
1720+
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
17241721
type <- match.arg(type)
17251722
formula <- paste(deparse(formula), collapse = "")
17261723
if (!is.null(seed)) {
@@ -1749,7 +1746,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
17491746
impurity, as.integer(minInstancesPerNode),
17501747
as.numeric(minInfoGain), as.integer(checkpointInterval),
17511748
as.character(featureSubsetStrategy), seed,
1752-
as.numeric(subsamplingRate), as.character(probabilityCol),
1749+
as.numeric(subsamplingRate),
17531750
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
17541751
new("RandomForestClassificationModel", jobj = jobj)
17551752
}

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,10 +926,10 @@ test_that("spark.posterior and spark.perplexity", {
926926

927927
test_that("spark.als", {
928928
data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
929-
list(2, 1, 1.0), list(2, 2, 5.0))
929+
list(2, 1, 1.0), list(2, 2, 5.0))
930930
df <- createDataFrame(data, c("user", "item", "score"))
931931
model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
932-
rank = 10, maxIter = 5, seed = 0, reg = 0.1)
932+
rank = 10, maxIter = 5, seed = 0, regParam = 0.1)
933933
stats <- summary(model)
934934
expect_equal(stats$rank, 10)
935935
test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))

mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ private[r] object LogisticRegressionWrapper
9696
family: String,
9797
standardization: Boolean,
9898
thresholds: Array[Double],
99-
weightCol: String,
100-
probabilityCol: String
99+
weightCol: String
101100
): LogisticRegressionWrapper = {
102101

103102
val rFormula = new RFormula()
@@ -123,7 +122,6 @@ private[r] object LogisticRegressionWrapper
123122
.setWeightCol(weightCol)
124123
.setFeaturesCol(rFormula.getFeaturesCol)
125124
.setLabelCol(rFormula.getLabelCol)
126-
.setProbabilityCol(probabilityCol)
127125
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
128126

129127
if (thresholds.length > 1) {

mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
7676
featureSubsetStrategy: String,
7777
seed: String,
7878
subsamplingRate: Double,
79-
probabilityCol: String,
8079
maxMemoryInMB: Int,
8180
cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
8281

@@ -102,7 +101,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
102101
.setSubsamplingRate(subsamplingRate)
103102
.setMaxMemoryInMB(maxMemoryInMB)
104103
.setCacheNodeIds(cacheNodeIds)
105-
.setProbabilityCol(probabilityCol)
106104
.setFeaturesCol(rFormula.getFeaturesCol)
107105
.setLabelCol(rFormula.getLabelCol)
108106
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)

0 commit comments

Comments (0)