@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
4545# ' @note RandomForestClassificationModel since 2.1.0
4646setClass ("RandomForestClassificationModel ", representation(jobj = "jobj"))
4747
48+ # ' S4 class that represents a DecisionTreeRegressionModel
49+ # '
50+ # ' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
51+ # ' @export
52+ # ' @note DecisionTreeRegressionModel since 2.3.0
53+ setClass ("DecisionTreeRegressionModel ", representation(jobj = "jobj"))
54+
55+ # ' S4 class that represents a DecisionTreeClassificationModel
56+ # '
57+ # ' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
58+ # ' @export
59+ # ' @note DecisionTreeClassificationModel since 2.3.0
60+ setClass ("DecisionTreeClassificationModel ", representation(jobj = "jobj"))
61+
4862# Create the summary of a tree ensemble model (eg. Random Forest, GBT)
4963summary.treeEnsemble <- function (model ) {
5064 jobj <- model @ jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
8195 invisible (x )
8296}
8397
98+ # Create the summary of a decision tree model
99+ summary.decisionTree <- function (model ) {
100+ jobj <- model @ jobj
101+ formula <- callJMethod(jobj , " formula" )
102+ numFeatures <- callJMethod(jobj , " numFeatures" )
103+ features <- callJMethod(jobj , " features" )
104+ featureImportances <- callJMethod(callJMethod(jobj , " featureImportances" ), " toString" )
105+ maxDepth <- callJMethod(jobj , " maxDepth" )
106+ list (formula = formula ,
107+ numFeatures = numFeatures ,
108+ features = features ,
109+ featureImportances = featureImportances ,
110+ maxDepth = maxDepth ,
111+ jobj = jobj )
112+ }
113+
114+ # Prints the summary of decision tree models
115+ print.summary.decisionTree <- function (x ) {
116+ jobj <- x $ jobj
117+ cat(" Formula: " , x $ formula )
118+ cat(" \n Number of features: " , x $ numFeatures )
119+ cat(" \n Features: " , unlist(x $ features ))
120+ cat(" \n Feature importances: " , x $ featureImportances )
121+ cat(" \n Max Depth: " , x $ maxDepth )
122+
123+ summaryStr <- callJMethod(jobj , " summary" )
124+ cat(" \n " , summaryStr , " \n " )
125+ invisible (x )
126+ }
127+
84128# ' Gradient Boosted Tree Model for Regression and Classification
85129# '
86130# ' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
499543 function (object , path , overwrite = FALSE ) {
500544 write_internal(object , path , overwrite )
501545 })
546+
547+ # ' Decision Tree Model for Regression and Classification
548+ # '
549+ # ' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
550+ # ' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
551+ # ' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
552+ # ' save/load fitted models.
553+ # ' For more details, see
554+ # ' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
555+ # ' Decision Tree Regression} and
556+ # ' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
557+ # ' Decision Tree Classification}
558+ # '
559+ # ' @param data a SparkDataFrame for training.
560+ # ' @param formula a symbolic description of the model to be fitted. Currently only a few formula
561+ # ' operators are supported, including '~', ':', '+', and '-'.
562+ # ' @param type type of model, one of "regression" or "classification", to fit
563+ # ' @param maxDepth Maximum depth of the tree (>= 0).
564+ # ' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
565+ # ' how to split on features at each node. More bins give higher granularity. Must be
566+ # ' >= 2 and >= number of categories in any categorical feature.
567+ # ' @param impurity Criterion used for information gain calculation.
568+ # ' For regression, must be "variance". For classification, must be one of
569+ # ' "entropy" and "gini", default is "gini".
570+ # ' @param seed integer seed for random number generation.
571+ # ' @param minInstancesPerNode Minimum number of instances each child must have after split.
572+ # ' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
573+ # ' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
574+ # ' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
575+ # ' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
576+ # ' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
577+ # ' can speed up training of deeper trees. Users can set how often should the
578+ # ' cache be checkpointed or disable it by setting checkpointInterval.
579+ # ' @param ... additional arguments passed to the method.
580+ # ' @aliases spark.decisionTree,SparkDataFrame,formula-method
581+ # ' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
582+ # ' @rdname spark.decisionTree
583+ # ' @name spark.decisionTree
584+ # ' @export
585+ # ' @examples
586+ # ' \dontrun{
587+ # ' # fit a Decision Tree Regression Model
588+ # ' df <- createDataFrame(longley)
589+ # ' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
590+ # '
591+ # ' # get the summary of the model
592+ # ' summary(model)
593+ # '
594+ # ' # make predictions
595+ # ' predictions <- predict(model, df)
596+ # '
597+ # ' # save and load the model
598+ # ' path <- "path/to/model"
599+ # ' write.ml(model, path)
600+ # ' savedModel <- read.ml(path)
601+ # ' summary(savedModel)
602+ # '
603+ # ' # fit a Decision Tree Classification Model
604+ # ' t <- as.data.frame(Titanic)
605+ # ' df <- createDataFrame(t)
606+ # ' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
607+ # ' }
608+ # ' @note spark.decisionTree since 2.3.0
609+ setMethod ("spark.decisionTree ", signature(data = "SparkDataFrame", formula = "formula"),
610+ function (data , formula , type = c(" regression" , " classification" ),
611+ maxDepth = 5 , maxBins = 32 , impurity = NULL , seed = NULL ,
612+ minInstancesPerNode = 1 , minInfoGain = 0.0 , checkpointInterval = 10 ,
613+ maxMemoryInMB = 256 , cacheNodeIds = FALSE ) {
614+ type <- match.arg(type )
615+ formula <- paste(deparse(formula ), collapse = " " )
616+ if (! is.null(seed )) {
617+ seed <- as.character(as.integer(seed ))
618+ }
619+ switch (type ,
620+ regression = {
621+ if (is.null(impurity )) impurity <- " variance"
622+ impurity <- match.arg(impurity , " variance" )
623+ jobj <- callJStatic(" org.apache.spark.ml.r.DecisionTreeRegressorWrapper" ,
624+ " fit" , data @ sdf , formula , as.integer(maxDepth ),
625+ as.integer(maxBins ), impurity ,
626+ as.integer(minInstancesPerNode ), as.numeric(minInfoGain ),
627+ as.integer(checkpointInterval ), seed ,
628+ as.integer(maxMemoryInMB ), as.logical(cacheNodeIds ))
629+ new(" DecisionTreeRegressionModel" , jobj = jobj )
630+ },
631+ classification = {
632+ if (is.null(impurity )) impurity <- " gini"
633+ impurity <- match.arg(impurity , c(" gini" , " entropy" ))
634+ jobj <- callJStatic(" org.apache.spark.ml.r.DecisionTreeClassifierWrapper" ,
635+ " fit" , data @ sdf , formula , as.integer(maxDepth ),
636+ as.integer(maxBins ), impurity ,
637+ as.integer(minInstancesPerNode ), as.numeric(minInfoGain ),
638+ as.integer(checkpointInterval ), seed ,
639+ as.integer(maxMemoryInMB ), as.logical(cacheNodeIds ))
640+ new(" DecisionTreeClassificationModel" , jobj = jobj )
641+ }
642+ )
643+ })
644+
645+ # Get the summary of a Decision Tree Regression Model
646+
647+ # ' @return \code{summary} returns summary information of the fitted model, which is a list.
648+ # ' The list of components includes \code{formula} (formula),
649+ # ' \code{numFeatures} (number of features), \code{features} (list of features),
650+ # ' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees).
651+ # ' @rdname spark.decisionTree
652+ # ' @aliases summary,DecisionTreeRegressionModel-method
653+ # ' @export
654+ # ' @note summary(DecisionTreeRegressionModel) since 2.3.0
655+ setMethod ("summary ", signature(object = "DecisionTreeRegressionModel"),
656+ function (object ) {
657+ ans <- summary.decisionTree(object )
658+ class(ans ) <- " summary.DecisionTreeRegressionModel"
659+ ans
660+ })
661+
662+ # Prints the summary of Decision Tree Regression Model
663+
664+ # ' @param x summary object of Decision Tree regression model or classification model
665+ # ' returned by \code{summary}.
666+ # ' @rdname spark.decisionTree
667+ # ' @export
668+ # ' @note print.summary.DecisionTreeRegressionModel since 2.3.0
669+ print.summary.DecisionTreeRegressionModel <- function (x , ... ) {
670+ print.summary.decisionTree(x )
671+ }
672+
673+ # Get the summary of a Decision Tree Classification Model
674+
675+ # ' @rdname spark.decisionTree
676+ # ' @aliases summary,DecisionTreeClassificationModel-method
677+ # ' @export
678+ # ' @note summary(DecisionTreeClassificationModel) since 2.3.0
679+ setMethod ("summary ", signature(object = "DecisionTreeClassificationModel"),
680+ function (object ) {
681+ ans <- summary.decisionTree(object )
682+ class(ans ) <- " summary.DecisionTreeClassificationModel"
683+ ans
684+ })
685+
686+ # Prints the summary of Decision Tree Classification Model
687+
688+ # ' @rdname spark.decisionTree
689+ # ' @export
690+ # ' @note print.summary.DecisionTreeClassificationModel since 2.3.0
691+ print.summary.DecisionTreeClassificationModel <- function (x , ... ) {
692+ print.summary.decisionTree(x )
693+ }
694+
695+ # Makes predictions from a Decision Tree Regression model or Classification model
696+
697+ # ' @param newData a SparkDataFrame for testing.
698+ # ' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
699+ # ' "prediction".
700+ # ' @rdname spark.decisionTree
701+ # ' @aliases predict,DecisionTreeRegressionModel-method
702+ # ' @export
703+ # ' @note predict(DecisionTreeRegressionModel) since 2.3.0
704+ setMethod ("predict ", signature(object = "DecisionTreeRegressionModel"),
705+ function (object , newData ) {
706+ predict_internal(object , newData )
707+ })
708+
709+ # ' @rdname spark.decisionTree
710+ # ' @aliases predict,DecisionTreeClassificationModel-method
711+ # ' @export
712+ # ' @note predict(DecisionTreeClassificationModel) since 2.3.0
713+ setMethod ("predict ", signature(object = "DecisionTreeClassificationModel"),
714+ function (object , newData ) {
715+ predict_internal(object , newData )
716+ })
717+
718+ # Save the Decision Tree Regression or Classification model to the input path.
719+
720+ # ' @param object A fitted Decision Tree regression model or classification model.
721+ # ' @param path The directory where the model is saved.
722+ # ' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
723+ # ' which means throw exception if the output path exists.
724+ # '
725+ # ' @aliases write.ml,DecisionTreeRegressionModel,character-method
726+ # ' @rdname spark.decisionTree
727+ # ' @export
728+ # ' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
729+ setMethod ("write.ml ", signature(object = "DecisionTreeRegressionModel", path = "character"),
730+ function (object , path , overwrite = FALSE ) {
731+ write_internal(object , path , overwrite )
732+ })
733+
734+ # ' @aliases write.ml,DecisionTreeClassificationModel,character-method
735+ # ' @rdname spark.decisionTree
736+ # ' @export
737+ # ' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
738+ setMethod ("write.ml ", signature(object = "DecisionTreeClassificationModel", path = "character"),
739+ function (object , path , overwrite = FALSE ) {
740+ write_internal(object , path , overwrite )
741+ })
0 commit comments