R/pkg/R/mllib.R (23 changes: 10 additions & 13 deletions)
@@ -733,7 +733,6 @@ setMethod("predict", signature(object = "KMeansModel"),
 #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #' is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
-#' @param probabilityCol column name for predicted class conditional probabilities.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model
 #' @rdname spark.logit
@@ -772,7 +771,7 @@ setMethod("predict", signature(object = "KMeansModel"),
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL, probabilityCol = "probability") {
+                   thresholds = 0.5, weightCol = NULL) {
             formula <- paste(deparse(formula), collapse = "")

             if (is.null(weightCol)) {
@@ -784,7 +783,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol), as.character(probabilityCol))
+                                as.character(weightCol))
             new("LogisticRegressionModel", jobj = jobj)
           })
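With probabilityCol dropped from the R signature, the underlying classifier keeps MLlib's default output column name ("probability"); the column can no longer be renamed from R. For reference, a minimal sketch of a call under the revised signature, assuming an attached SparkR package with an active session (the data set and column names are illustrative, not part of this change):

# Hypothetical example; assumes a running SparkR session.
df <- suppressWarnings(createDataFrame(iris))  # dots in column names become underscores
model <- spark.logit(df, Species ~ Sepal_Length + Sepal_Width, regParam = 0.01)
summary(model)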

@@ -1425,7 +1424,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
 #' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
 #' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
 #' @param rank rank of the matrix factorization (> 0).
-#' @param reg regularization parameter (>= 0).
+#' @param regParam regularization parameter (>= 0).
 #' @param maxIter maximum number of iterations (>= 0).
 #' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
 #' @param implicitPrefs logical value indicating whether to use implicit preference.
@@ -1464,29 +1463,29 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
 #'
 #' # set other arguments
 #' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
-#'                     reg = 0.1, nonnegative = TRUE)
+#'                     regParam = 0.1, nonnegative = TRUE)
 #' statsS <- summary(modelS)
 #' }
 #' @note spark.als since 2.1.0
 setMethod("spark.als", signature(data = "SparkDataFrame"),
           function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
-                   rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE,
+                   rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
                    implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
                    checkpointInterval = 10, seed = 0) {

             if (!is.numeric(rank) || rank <= 0) {
               stop("rank should be a positive number.")
             }
-            if (!is.numeric(reg) || reg < 0) {
-              stop("reg should be a nonnegative number.")
+            if (!is.numeric(regParam) || regParam < 0) {
+              stop("regParam should be a nonnegative number.")
             }
             if (!is.numeric(maxIter) || maxIter <= 0) {
               stop("maxIter should be a positive number.")
             }

             jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
                                 "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
-                                reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+                                regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
                                 as.integer(numUserBlocks), as.integer(numItemBlocks),
                                 as.integer(checkpointInterval), as.integer(seed))
             new("ALSModel", jobj = jobj)
@@ -1684,8 +1683,6 @@ print.summary.KSTest <- function(x, ...) {
 #'                nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                can speed up training of deeper trees. Users can set how often should the
 #'                cache be checkpointed or disable it by setting checkpointInterval.
-#' @param probabilityCol column name for predicted class conditional probabilities, only for
-#'        classification.
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -1720,7 +1717,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -1749,7 +1746,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                      impurity, as.integer(minInstancesPerNode),
                                      as.numeric(minInfoGain), as.integer(checkpointInterval),
                                      as.character(featureSubsetStrategy), seed,
-                                     as.numeric(subsamplingRate), as.character(probabilityCol),
+                                     as.numeric(subsamplingRate),
                                      as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
             new("RandomForestClassificationModel", jobj = jobj)
           }
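As with spark.logit, the class-conditional probability column for random forest classification is no longer configurable from R and keeps its default name. A minimal sketch under the revised signature, assuming an active SparkR session (data set is illustrative):

# Hypothetical example; assumes a running SparkR session.
df <- suppressWarnings(createDataFrame(iris))
model <- spark.randomForest(df, Species ~ ., "classification",
                            numTrees = 10, maxDepth = 4, seed = 42)
head(predict(model, df))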
R/pkg/inst/tests/testthat/test_mllib.R (4 changes: 2 additions & 2 deletions)
@@ -926,10 +926,10 @@ test_that("spark.posterior and spark.perplexity", {

 test_that("spark.als", {
   data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
-              list(2, 1, 1.0), list(2, 2, 5.0))
+               list(2, 1, 1.0), list(2, 2, 5.0))
   df <- createDataFrame(data, c("user", "item", "score"))
   model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
-                     rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+                     rank = 10, maxIter = 5, seed = 0, regParam = 0.1)
   stats <- summary(model)
   expect_equal(stats$rank, 10)
   test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,8 +96,7 @@ private[r] object LogisticRegressionWrapper
       family: String,
       standardization: Boolean,
       thresholds: Array[Double],
-      weightCol: String,
-      probabilityCol: String
+      weightCol: String
     ): LogisticRegressionWrapper = {

     val rFormula = new RFormula()
@@ -123,7 +122,6 @@ private[r] object LogisticRegressionWrapper
       .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
-      .setProbabilityCol(probabilityCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)

     if (thresholds.length > 1) {
mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassifierWrapper.scala
@@ -76,7 +76,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       featureSubsetStrategy: String,
       seed: String,
       subsamplingRate: Double,
-      probabilityCol: String,
       maxMemoryInMB: Int,
       cacheNodeIds: Boolean): RandomForestClassifierWrapper = {

@@ -102,7 +101,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .setSubsamplingRate(subsamplingRate)
       .setMaxMemoryInMB(maxMemoryInMB)
       .setCacheNodeIds(cacheNodeIds)
-      .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)