Skip to content

Commit 463f965

Browse files
committed
add
1 parent bee4868 commit 463f965

File tree

1 file changed

+48
-19
lines changed

1 file changed

+48
-19
lines changed

R/pkg/R/mllib.R

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
118118
#' @export
119119
#' @seealso \link{spark.glm}, \link{glm},
120120
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
121-
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
121+
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg},
122+
#' @seealso \link{spark.decisionTree},
122123
#' @seealso \link{read.ml}
123124
NULL
124125

@@ -131,7 +132,7 @@ NULL
131132
#' @export
132133
#' @seealso \link{spark.glm}, \link{glm},
133134
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
134-
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
135+
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.decisionTree}
135136
NULL
136137

137138
write_internal <- function(object, path, overwrite = FALSE) {
@@ -911,22 +912,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
911912
write_internal(object, path, overwrite)
912913
})
913914

914-
#' Save the Decision Tree Regression model to the input path.
915-
#'
916-
#' @param object A fitted Decision tree regression model
917-
#' @param path The directory where the model is saved
918-
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
919-
#' which means throw exception if the output path exists.
920-
#'
921-
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
922-
#' @rdname spark.decisionTreeRegression
923-
#' @export
924-
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
925-
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
926-
function(object, path, overwrite = FALSE) {
927-
write_internal(object, path, overwrite)
928-
})
929-
930915
#' Load a fitted MLlib model from the input path.
931916
#'
932917
#' @param path path of the model to read.
@@ -964,6 +949,8 @@ read.ml <- function(path) {
964949
new("ALSModel", jobj = jobj)
965950
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
966951
new("DecisionTreeRegressionModel", jobj = jobj)
952+
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
953+
new("DecisionTreeClassificationModel", jobj = jobj)
967954
} else {
968955
stop("Unsupported model: ", jobj)
969956
}
@@ -1477,6 +1464,7 @@ print.summary.KSTest <- function(x, ...) {
14771464
#' df <- createDataFrame(sqlContext, kyphosis)
14781465
#' model <- spark.decisionTree(df, Kyphosis ~ Age + Number + Start)
14791466
#' }
1467+
#' @note spark.decisionTree since 2.1.0
14801468
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
14811469
function(data, formula, type = c("regression", "classification")) {
14821470
formula <- paste(deparse(formula), collapse = "")
@@ -1491,7 +1479,48 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
14911479
}
14921480
})
14931481

1482+
# Makes predictions from a Decision Tree model or a model produced by spark.decisionTree()
1483+
1484+
#' @param newData a SparkDataFrame for testing.
1485+
#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
1486+
#' "prediction"
1487+
#' @rdname spark.decisionTree
1488+
#' @export
1489+
#' @note predict(decisionTreeRegressionModel) since 2.1.0
14941490
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
14951491
function(object, newData) {
1496-
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
1492+
predict_internal(object, newData)
14971493
})
1494+
1495+
#' Save the Decision Tree Regression model to the input path.
1496+
#'
1497+
#' @param object A fitted Decision tree regression model
1498+
#' @param path The directory where the model is saved
1499+
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
1500+
#' which means throw exception if the output path exists.
1501+
#'
1502+
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
1503+
#' @rdname spark.decisionTreeRegression
1504+
#' @export
1505+
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
1506+
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
1507+
function(object, path, overwrite = FALSE) {
1508+
write_internal(object, path, overwrite)
1509+
})
1510+
1511+
# Get the summary of an IsotonicRegressionModel model
1512+
1513+
#' @param object a fitted IsotonicRegressionModel
1514+
#' @param ... Other optional arguments to summary of an IsotonicRegressionModel
1515+
#' @return \code{summary} returns the model's boundaries and prediction as lists
1516+
#' @rdname spark.isoreg
1517+
#' @aliases summary,IsotonicRegressionModel-method
1518+
#' @export
1519+
#' @note summary(IsotonicRegressionModel) since 2.1.0
1520+
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
1521+
function(object, ...) {
1522+
jobj <- object@jobj
1523+
boundaries <- callJMethod(jobj, "boundaries")
1524+
predictions <- callJMethod(jobj, "predictions")
1525+
return(list(boundaries = boundaries, predictions = predictions))
1526+
})

0 commit comments

Comments
 (0)