@@ -118,7 +118,8 @@ setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
118118# ' @export
119119# ' @seealso \link{spark.glm}, \link{glm},
120120# ' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
121- # ' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
121+ # ' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg},
122+ # ' @seealso \link{spark.decisionTree},
122123# ' @seealso \link{read.ml}
123124NULL
124125
131132# ' @export
132133# ' @seealso \link{spark.glm}, \link{glm},
133134# ' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
134- # ' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
135+ # ' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.decisionTree}
135136NULL
136137
137138write_internal <- function (object , path , overwrite = FALSE ) {
@@ -911,22 +912,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
911912 write_internal(object , path , overwrite )
912913 })
913914
914- # ' Save the Decision Tree Regression model to the input path.
915- # '
916- # ' @param object A fitted Decision tree regression model
917- # ' @param path The directory where the model is saved
918- # ' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
919- # ' which means throw exception if the output path exists.
920- # '
921- # ' @aliases write.ml,DecisionTreeRegressionModel,character-method
922- # ' @rdname spark.decisionTreeRegression
923- # ' @export
924- # ' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
925- setMethod ("write.ml ", signature(object = "DecisionTreeRegressionModel", path = "character"),
926- function (object , path , overwrite = FALSE ) {
927- write_internal(object , path , overwrite )
928- })
929-
930915# ' Load a fitted MLlib model from the input path.
931916# '
932917# ' @param path path of the model to read.
@@ -964,6 +949,8 @@ read.ml <- function(path) {
964949 new(" ALSModel" , jobj = jobj )
965950 } else if (isInstanceOf(jobj , " org.apache.spark.ml.r.DecisionTreeRegressorWrapper" )) {
966951 new(" DecisionTreeRegressionModel" , jobj = jobj )
952+ } else if (isInstanceOf(jobj , " org.apache.spark.ml.r.DecisionTreeClassifierWrapper" )) {
953+ new(" DecisionTreeClassificationModel" , jobj = jobj )
967954 } else {
968955 stop(" Unsupported model: " , jobj )
969956 }
@@ -1477,6 +1464,7 @@ print.summary.KSTest <- function(x, ...) {
14771464# ' df <- createDataFrame(sqlContext, kyphosis)
14781465# ' model <- spark.decisionTree(df, Kyphosis ~ Age + Number + Start)
14791466# ' }
1467+ # ' @note spark.decisionTree since 2.1.0
14801468setMethod ("spark.decisionTree ", signature(data = "SparkDataFrame", formula = "formula"),
14811469 function (data , formula , type = c(" regression" , " classification" )) {
14821470 formula <- paste(deparse(formula ), collapse = " " )
@@ -1491,7 +1479,48 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
14911479 }
14921480 })
14931481
1482+ # Makes predictions from a Decision Tree model or a model produced by spark.decisionTree()
1483+
1484+ # ' @param newData a SparkDataFrame for testing.
1485+ # ' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
1486+ # ' "prediction"
1487+ # ' @rdname spark.decisionTree
1488+ # ' @export
1489+ # ' @note predict(decisionTreeRegressionModel) since 2.1.0
14941490setMethod ("predict ", signature(object = "DecisionTreeRegressionModel"),
14951491 function (object , newData ) {
1496- return (dataFrame(callJMethod( object @ jobj , " transform " , newData @ sdf )) )
1492+ predict_internal( object , newData )
14971493 })
1494+
1495+ # ' Save the Decision Tree Regression model to the input path.
1496+ # '
1497+ # ' @param object A fitted Decision tree regression model
1498+ # ' @param path The directory where the model is saved
1499+ # ' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
1500+ # ' which means throw exception if the output path exists.
1501+ # '
1502+ # ' @aliases write.ml,DecisionTreeRegressionModel,character-method
1503+ # ' @rdname spark.decisionTreeRegression
1504+ # ' @export
1505+ # ' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
1506+ setMethod ("write.ml ", signature(object = "DecisionTreeRegressionModel", path = "character"),
1507+ function (object , path , overwrite = FALSE ) {
1508+ write_internal(object , path , overwrite )
1509+ })
1510+
1511+ # Get the summary of an IsotonicRegressionModel model
1512+
1513+ # ' @param object a fitted IsotonicRegressionModel
1514+ # ' @param ... Other optional arguments to summary of an IsotonicRegressionModel
1515+ # ' @return \code{summary} returns the model's boundaries and prediction as lists
1516+ # ' @rdname spark.isoreg
1517+ # ' @aliases summary,IsotonicRegressionModel-method
1518+ # ' @export
1519+ # ' @note summary(IsotonicRegressionModel) since 2.1.0
1520+ setMethod ("summary ", signature(object = "DecisionTreeRegressionModel"),
1521+ function (object , ... ) {
1522+ jobj <- object @ jobj
1523+ boundaries <- callJMethod(jobj , " boundaries" )
1524+ predictions <- callJMethod(jobj , " predictions" )
1525+ return (list (boundaries = boundaries , predictions = predictions ))
1526+ })
0 commit comments