@@ -1447,39 +1447,31 @@ print.summary.KSTest <- function(x, ...) {
14471447 invisible (x )
14481448}
14491449
1450- # ' Decision tree regression model.
1450+ # ' Decision Tree
14511451# '
1452- # ' Fit Decision Tree regression model on a SparkDataFrame.
1452+ # ' @description
1453+ # ' \code{spark.decisionTree} tree
14531454# '
1454- # ' @param data SparkDataFrame for training.
1455- # ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
1456- # ' operators are supported, including '~', ':', '+', and '-'.
1457- # ' Note that operator '.' is not supported currently.
1458- # ' @return a fitted decision tree regression model
1459- # ' @rdname spark.decisionTreeRegressor
1460- # ' @seealso rpart: \url{https://cran.r-project.org/web/packages/rpart/}
1461- # ' @export
1462- # ' @examples
1463- # ' \dontrun{
1464- # ' df <- createDataFrame(sqlContext, kyphosis)
1465- # ' model <- spark.decisionTree(df, Kyphosis ~ Age + Number + Start)
1466- # ' }
1455+ # ' Decision Tree
1456+ # '
1457+ # ' @param data a SparkDataFrame of user data.
14671458# ' @note spark.decisionTree since 2.1.0
14681459setMethod ("spark.decisionTree ", signature(data = "SparkDataFrame", formula = "formula"),
1469- function (data , formula , type = c(" regression" , " classification" )) {
1460+ function (data , formula , type = c(" regression" , " classification" ), maxDepth = 5 , maxBins = 32 ) {
14701461 formula <- paste(deparse(formula ), collapse = " " )
14711462 if (identical(type , " regression" )) {
14721463 jobj <- callJStatic(" org.apache.spark.ml.r.DecisionTreeRegressorWrapper" , " fit" ,
1473- data @ sdf , formula )
1464+ data @ sdf , formula , as.integer( maxDepth ), as.integer( maxBins ) )
14741465 new(" DecisionTreeRegressionModel" , jobj = jobj )
14751466 } else if (identical(type , " classification" )) {
1476- jobj <- callJStatic(" org.apache.spark.ml.r.DecisionTreeClassificationWrapper " , " fit" ,
1477- data @ sdf , formula )
1467+ jobj <- callJStatic(" org.apache.spark.ml.r.DecisionTreeClassifierWrapper " , " fit" ,
1468+ data @ sdf , formula , as.integer( maxDepth ), as.integer( maxBins ) )
14781469 new(" DecisionTreeClassificationModel" , jobj = jobj )
14791470 }
14801471 })
14811472
1482- # Makes predictions from a Decision Tree model or a model produced by spark.decisionTree()
1473+ # Makes predictions from a Decision Tree Regression model or
1474+ # a model produced by spark.decisionTree()
14831475
14841476# ' @param newData a SparkDataFrame for testing.
14851477# ' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
@@ -1492,6 +1484,20 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
14921484 predict_internal(object , newData )
14931485 })
14941486
1487+ # Makes predictions from a Decision Tree Classification model or
1488+ # a model produced by spark.decisionTree()
1489+
1490+ # ' @param newData a SparkDataFrame for testing.
1491+ # ' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
1492+ # ' "prediction"
1493+ # ' @rdname spark.decisionTree
1494+ # ' @export
1495+ # ' @note predict(decisionTreeClassificationModel) since 2.1.0
1496+ setMethod ("predict ", signature(object = "DecisionTreeClassificationModel"),
1497+ function (object , newData ) {
1498+ predict_internal(object , newData )
1499+ })
1500+
14951501# ' Save the Decision Tree Regression model to the input path.
14961502# '
14971503# ' @param object A fitted Decision tree regression model
@@ -1504,23 +1510,88 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
15041510# ' @export
15051511# ' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
15061512setMethod ("write.ml ", signature(object = "DecisionTreeRegressionModel", path = "character"),
1507- function (object , path , overwrite = FALSE ) {
1508- write_internal(object , path , overwrite )
1509- })
1513+ function (object , path , overwrite = FALSE ) {
1514+ write_internal(object , path , overwrite )
1515+ })
15101516
1511- # Get the summary of an IsotonicRegressionModel model
1517+ # ' Save the Decision Tree Classification model to the input path.
1518+ # '
1519+ # ' @param object A fitted Decision tree classification model
1520+ # ' @param path The directory where the model is saved
1521+ # ' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
1522+ # ' which means throw exception if the output path exists.
1523+ # '
1524+ # ' @aliases write.ml,DecisionTreeClassificationModel,character-method
1525+ # ' @rdname spark.decisionTreeClassification
1526+ # ' @export
1527+ # ' @note write.ml(DecisionTreeClassificationModel, character) since 2.1.0
1528+ setMethod ("write.ml ", signature(object = "DecisionTreeClassificationModel", path = "character"),
1529+ function (object , path , overwrite = FALSE ) {
1530+ write_internal(object , path , overwrite )
1531+ })
15121532
1513- # ' @param object a fitted IsotonicRegressionModel
1514- # ' @param ... Other optional arguments to summary of an IsotonicRegressionModel
1515- # ' @return \code{summary} returns the model's boundaries and prediction as lists
1516- # ' @rdname spark.isoreg
1517- # ' @aliases summary,IsotonicRegressionModel-method
1533+ # Get the summary of an DecisionTreeRegressionModel model
1534+
1535+ # ' @param object a fitted DecisionTreeRegressionModel
1536+ # ' @param ... Other optional arguments to summary of a DecisionTreeRegressionModel
1537+ # ' @return \code{summary} returns the model's features as lists, depth and number of nodes
1538+ # ' @rdname spark.decisionTree
1539+ # ' @aliases summary,DecisionTreeRegressionModel-method
15181540# ' @export
1519- # ' @note summary(IsotonicRegressionModel ) since 2.1.0
1541+ # ' @note summary(DecisionTreeRegressionModel ) since 2.1.0
15201542setMethod ("summary ", signature(object = "DecisionTreeRegressionModel"),
15211543 function (object , ... ) {
15221544 jobj <- object @ jobj
1523- boundaries <- callJMethod(jobj , " boundaries" )
1524- predictions <- callJMethod(jobj , " predictions" )
1525- return (list (boundaries = boundaries , predictions = predictions ))
1526- })
1545+ features <- callJMethod(jobj , " features" )
1546+ depth <- callJMethod(jobj , " depth" )
1547+ numNodes <- callJMethod(jobj , " numNodes" )
1548+ ans <- list (features = features , depth = depth , numNodes = numNodes )
1549+ class(ans ) <- " summary.DecisionTreeRegressionModel"
1550+ ans
1551+ })
1552+
1553+ # Get the summary of an DecisionTreeClassificationModel model
1554+
1555+ # ' @param object a fitted DecisionTreeClassificationModel
1556+ # ' @param ... Other optional arguments to summary of a DecisionTreeClassificationModel
1557+ # ' @return \code{summary} returns the model's features as lists, depth and number of nodes
1558+ # ' @rdname spark.decisionTree
1559+ # ' @aliases summary,DecisionTreeClassificationModel-method
1560+ # ' @export
1561+ # ' @note summary(DecisionTreeRegressionModel) since 2.1.0
1562+ setMethod ("summary ", signature(object = "DecisionTreeClassificationModel"),
1563+ function (object , ... ) {
1564+ jobj <- object @ jobj
1565+ features <- callJMethod(jobj , " features" )
1566+ depth <- callJMethod(jobj , " depth" )
1567+ numNodes <- callJMethod(jobj , " numNodes" )
1568+ ans <- list (features = features , depth = depth , numNodes = numNodes )
1569+ class(ans ) <- " summary.DecisionTreeClassificationModel"
1570+ ans
1571+ })
1572+
1573+ # Prints the summary of Decision Tree Regression Model
1574+
1575+ # ' @rdname spark.decisionTree
1576+ # ' @param x summary object of decisionTreeRegressionModel returned by \code{summary}.
1577+ # ' @export
1578+ # ' @note print.summary.DecisionTreeRegressionModel since 2.1.0
1579+ print.summary.DecisionTreeRegressionModel <- function (x , ... ) {
1580+ jobj <- x @ jobj
1581+ summaryStr <- callJMethod(jobj , " summary" )
1582+ cat(summaryStr , " \n " )
1583+ invisible (x )
1584+ }
1585+
1586+ # Prints the summary of Decision Tree Classification Model
1587+
1588+ # ' @rdname spark.decisionTree
1589+ # ' @param x summary object of decisionTreeClassificationModel returned by \code{summary}.
1590+ # ' @export
1591+ # ' @note print.summary.DecisionTreeClassificationModel since 2.1.0
1592+ print.summary.DecisionTreeClassificationModel <- function (x , ... ) {
1593+ jobj <- x @ jobj
1594+ summaryStr <- callJMethod(jobj , " summary" )
1595+ cat(summaryStr , " \n " )
1596+ invisible (x )
1597+ }
0 commit comments