Skip to content

Commit 7bbfe3a

Browse files
author
Peng
committed
Merge remote-tracking branch 'origin/master' into moreTest
2 parents a8b407f + d06610f commit 7bbfe3a

File tree

64 files changed

+1647
-596
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1647
-596
lines changed

R/pkg/NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ exportMethods("glm",
6363
"spark.als",
6464
"spark.kstest",
6565
"spark.logit",
66+
"spark.decisionTree",
6667
"spark.randomForest",
6768
"spark.gbt",
6869
"spark.bisectingKmeans",
@@ -414,6 +415,8 @@ export("as.DataFrame",
414415
"print.summary.GeneralizedLinearRegressionModel",
415416
"read.ml",
416417
"print.summary.KSTest",
418+
"print.summary.DecisionTreeRegressionModel",
419+
"print.summary.DecisionTreeClassificationModel",
417420
"print.summary.RandomForestRegressionModel",
418421
"print.summary.RandomForestClassificationModel",
419422
"print.summary.GBTRegressionModel",
@@ -452,6 +455,8 @@ S3method(print, structField)
452455
S3method(print, structType)
453456
S3method(print, summary.GeneralizedLinearRegressionModel)
454457
S3method(print, summary.KSTest)
458+
S3method(print, summary.DecisionTreeRegressionModel)
459+
S3method(print, summary.DecisionTreeClassificationModel)
455460
S3method(print, summary.RandomForestRegressionModel)
456461
S3method(print, summary.RandomForestClassificationModel)
457462
S3method(print, summary.GBTRegressionModel)

R/pkg/R/generics.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml
15061506
#' @export
15071507
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
15081508

1509+
#' @rdname spark.decisionTree
1510+
#' @export
1511+
setGeneric("spark.decisionTree",
1512+
function(data, formula, ...) { standardGeneric("spark.decisionTree") })
1513+
15091514
#' @rdname spark.randomForest
15101515
#' @export
15111516
setGeneric("spark.randomForest",

R/pkg/R/mllib_tree.R

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
4545
#' @note RandomForestClassificationModel since 2.1.0
4646
setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
4747

48+
#' S4 class that represents a DecisionTreeRegressionModel
49+
#'
50+
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
51+
#' @export
52+
#' @note DecisionTreeRegressionModel since 2.3.0
53+
setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))
54+
55+
#' S4 class that represents a DecisionTreeClassificationModel
56+
#'
57+
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
58+
#' @export
59+
#' @note DecisionTreeClassificationModel since 2.3.0
60+
setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
61+
4862
# Create the summary of a tree ensemble model (e.g. Random Forest, GBT)
4963
summary.treeEnsemble <- function(model) {
5064
jobj <- model@jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
8195
invisible(x)
8296
}
8397

98+
# Create the summary of a decision tree model
99+
summary.decisionTree <- function(model) {
100+
jobj <- model@jobj
101+
formula <- callJMethod(jobj, "formula")
102+
numFeatures <- callJMethod(jobj, "numFeatures")
103+
features <- callJMethod(jobj, "features")
104+
featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
105+
maxDepth <- callJMethod(jobj, "maxDepth")
106+
list(formula = formula,
107+
numFeatures = numFeatures,
108+
features = features,
109+
featureImportances = featureImportances,
110+
maxDepth = maxDepth,
111+
jobj = jobj)
112+
}
113+
114+
# Prints the summary of decision tree models
115+
print.summary.decisionTree <- function(x) {
116+
jobj <- x$jobj
117+
cat("Formula: ", x$formula)
118+
cat("\nNumber of features: ", x$numFeatures)
119+
cat("\nFeatures: ", unlist(x$features))
120+
cat("\nFeature importances: ", x$featureImportances)
121+
cat("\nMax Depth: ", x$maxDepth)
122+
123+
summaryStr <- callJMethod(jobj, "summary")
124+
cat("\n", summaryStr, "\n")
125+
invisible(x)
126+
}
127+
84128
#' Gradient Boosted Tree Model for Regression and Classification
85129
#'
86130
#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
499543
function(object, path, overwrite = FALSE) {
500544
write_internal(object, path, overwrite)
501545
})
546+
547+
#' Decision Tree Model for Regression and Classification
548+
#'
549+
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
550+
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
551+
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
552+
#' save/load fitted models.
553+
#' For more details, see
554+
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
555+
#' Decision Tree Regression} and
556+
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
557+
#' Decision Tree Classification}
558+
#'
559+
#' @param data a SparkDataFrame for training.
560+
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
561+
#' operators are supported, including '~', ':', '+', and '-'.
562+
#' @param type type of model, one of "regression" or "classification", to fit
563+
#' @param maxDepth Maximum depth of the tree (>= 0).
564+
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
565+
#' how to split on features at each node. More bins give higher granularity. Must be
566+
#' >= 2 and >= number of categories in any categorical feature.
567+
#' @param impurity Criterion used for information gain calculation.
568+
#' For regression, must be "variance". For classification, must be one of
569+
#' "entropy" and "gini", default is "gini".
570+
#' @param seed integer seed for random number generation.
571+
#' @param minInstancesPerNode Minimum number of instances each child must have after split.
572+
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
573+
#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
574+
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
575+
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
576+
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
577+
#' can speed up training of deeper trees. Users can set how often the cache
578+
#' should be checkpointed, or disable checkpointing, by setting checkpointInterval.
579+
#' @param ... additional arguments passed to the method.
580+
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
581+
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
582+
#' @rdname spark.decisionTree
583+
#' @name spark.decisionTree
584+
#' @export
585+
#' @examples
586+
#' \dontrun{
587+
#' # fit a Decision Tree Regression Model
588+
#' df <- createDataFrame(longley)
589+
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
590+
#'
591+
#' # get the summary of the model
592+
#' summary(model)
593+
#'
594+
#' # make predictions
595+
#' predictions <- predict(model, df)
596+
#'
597+
#' # save and load the model
598+
#' path <- "path/to/model"
599+
#' write.ml(model, path)
600+
#' savedModel <- read.ml(path)
601+
#' summary(savedModel)
602+
#'
603+
#' # fit a Decision Tree Classification Model
604+
#' t <- as.data.frame(Titanic)
605+
#' df <- createDataFrame(t)
606+
#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
607+
#' }
608+
#' @note spark.decisionTree since 2.3.0
609+
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
610+
function(data, formula, type = c("regression", "classification"),
611+
maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
612+
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
613+
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
614+
type <- match.arg(type)
615+
formula <- paste(deparse(formula), collapse = "")
616+
if (!is.null(seed)) {
617+
seed <- as.character(as.integer(seed))
618+
}
619+
switch(type,
620+
regression = {
621+
if (is.null(impurity)) impurity <- "variance"
622+
impurity <- match.arg(impurity, "variance")
623+
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
624+
"fit", data@sdf, formula, as.integer(maxDepth),
625+
as.integer(maxBins), impurity,
626+
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
627+
as.integer(checkpointInterval), seed,
628+
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
629+
new("DecisionTreeRegressionModel", jobj = jobj)
630+
},
631+
classification = {
632+
if (is.null(impurity)) impurity <- "gini"
633+
impurity <- match.arg(impurity, c("gini", "entropy"))
634+
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
635+
"fit", data@sdf, formula, as.integer(maxDepth),
636+
as.integer(maxBins), impurity,
637+
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
638+
as.integer(checkpointInterval), seed,
639+
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
640+
new("DecisionTreeClassificationModel", jobj = jobj)
641+
}
642+
)
643+
})
644+
645+
# Get the summary of a Decision Tree Regression Model
646+
647+
#' @return \code{summary} returns summary information of the fitted model, which is a list.
648+
#' The list of components includes \code{formula} (formula),
649+
#' \code{numFeatures} (number of features), \code{features} (list of features),
650+
#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of the tree).
651+
#' @rdname spark.decisionTree
652+
#' @aliases summary,DecisionTreeRegressionModel-method
653+
#' @export
654+
#' @note summary(DecisionTreeRegressionModel) since 2.3.0
655+
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
656+
function(object) {
657+
ans <- summary.decisionTree(object)
658+
class(ans) <- "summary.DecisionTreeRegressionModel"
659+
ans
660+
})
661+
662+
# Prints the summary of Decision Tree Regression Model
663+
664+
#' @param x summary object of Decision Tree regression model or classification model
665+
#' returned by \code{summary}.
666+
#' @rdname spark.decisionTree
667+
#' @export
668+
#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
669+
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
670+
print.summary.decisionTree(x)
671+
}
672+
673+
# Get the summary of a Decision Tree Classification Model
674+
675+
#' @rdname spark.decisionTree
676+
#' @aliases summary,DecisionTreeClassificationModel-method
677+
#' @export
678+
#' @note summary(DecisionTreeClassificationModel) since 2.3.0
679+
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
680+
function(object) {
681+
ans <- summary.decisionTree(object)
682+
class(ans) <- "summary.DecisionTreeClassificationModel"
683+
ans
684+
})
685+
686+
# Prints the summary of Decision Tree Classification Model
687+
688+
#' @rdname spark.decisionTree
689+
#' @export
690+
#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
691+
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
692+
print.summary.decisionTree(x)
693+
}
694+
695+
# Makes predictions from a Decision Tree Regression model or Classification model
696+
697+
#' @param newData a SparkDataFrame for testing.
698+
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
699+
#' "prediction".
700+
#' @rdname spark.decisionTree
701+
#' @aliases predict,DecisionTreeRegressionModel-method
702+
#' @export
703+
#' @note predict(DecisionTreeRegressionModel) since 2.3.0
704+
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
705+
function(object, newData) {
706+
predict_internal(object, newData)
707+
})
708+
709+
#' @rdname spark.decisionTree
710+
#' @aliases predict,DecisionTreeClassificationModel-method
711+
#' @export
712+
#' @note predict(DecisionTreeClassificationModel) since 2.3.0
713+
setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
714+
function(object, newData) {
715+
predict_internal(object, newData)
716+
})
717+
718+
# Save the Decision Tree Regression or Classification model to the input path.
719+
720+
#' @param object A fitted Decision Tree regression model or classification model.
721+
#' @param path The directory where the model is saved.
722+
#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
723+
#' which means an exception is thrown if the output path exists.
724+
#'
725+
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
726+
#' @rdname spark.decisionTree
727+
#' @export
728+
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
729+
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
730+
function(object, path, overwrite = FALSE) {
731+
write_internal(object, path, overwrite)
732+
})
733+
734+
#' @aliases write.ml,DecisionTreeClassificationModel,character-method
735+
#' @rdname spark.decisionTree
736+
#' @export
737+
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
738+
setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
739+
function(object, path, overwrite = FALSE) {
740+
write_internal(object, path, overwrite)
741+
})

R/pkg/R/mllib_utils.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
#' @rdname write.ml
3333
#' @name write.ml
3434
#' @export
35-
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
36-
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
35+
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
36+
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
37+
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
3738
#' @seealso \link{spark.kmeans},
3839
#' @seealso \link{spark.lda}, \link{spark.logit},
3940
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
@@ -48,8 +49,9 @@ NULL
4849
#' @rdname predict
4950
#' @name predict
5051
#' @export
51-
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
52-
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
52+
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
53+
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
54+
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
5355
#' @seealso \link{spark.kmeans},
5456
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
5557
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
@@ -110,6 +112,10 @@ read.ml <- function(path) {
110112
new("RandomForestRegressionModel", jobj = jobj)
111113
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
112114
new("RandomForestClassificationModel", jobj = jobj)
115+
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
116+
new("DecisionTreeRegressionModel", jobj = jobj)
117+
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
118+
new("DecisionTreeClassificationModel", jobj = jobj)
113119
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) {
114120
new("GBTRegressionModel", jobj = jobj)
115121
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) {

R/pkg/R/utils.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,19 @@ basenameSansExtFromUrl <- function(url) {
907907
isAtomicLengthOne <- function(x) {
908908
is.atomic(x) && length(x) == 1
909909
}
910+
911+
is_cran <- function() {
912+
!identical(Sys.getenv("NOT_CRAN"), "true")
913+
}
914+
915+
is_windows <- function() {
916+
.Platform$OS.type == "windows"
917+
}
918+
919+
hadoop_home_set <- function() {
920+
!identical(Sys.getenv("HADOOP_HOME"), "")
921+
}
922+
923+
not_cran_or_windows_with_hadoop <- function() {
924+
!is_cran() && (!is_windows() || hadoop_home_set())
925+
}

0 commit comments

Comments
 (0)