
Commit 6c7d259

Merge remote-tracking branch 'origin/master' into execution
2 parents f49a0b3 + 35b644b commit 6c7d259

File tree: 203 files changed (+4236, −1254 lines). This is a large commit; only a subset of the changed files is shown below.

R/pkg/NAMESPACE

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,7 @@ exportMethods("glm",
               "spark.als",
               "spark.kstest",
               "spark.logit",
+              "spark.decisionTree",
               "spark.randomForest",
               "spark.gbt",
               "spark.bisectingKmeans",
@@ -414,6 +415,8 @@ export("as.DataFrame",
        "print.summary.GeneralizedLinearRegressionModel",
        "read.ml",
        "print.summary.KSTest",
+       "print.summary.DecisionTreeRegressionModel",
+       "print.summary.DecisionTreeClassificationModel",
        "print.summary.RandomForestRegressionModel",
        "print.summary.RandomForestClassificationModel",
        "print.summary.GBTRegressionModel",
@@ -452,6 +455,8 @@ S3method(print, structField)
 S3method(print, structType)
 S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
+S3method(print, summary.DecisionTreeRegressionModel)
+S3method(print, summary.DecisionTreeClassificationModel)
 S3method(print, summary.RandomForestRegressionModel)
 S3method(print, summary.RandomForestClassificationModel)
 S3method(print, summary.GBTRegressionModel)
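
For readers unfamiliar with R's NAMESPACE mechanics: the S3method(print, ...) entries register the new print.summary.* functions so that a plain print() call dispatches on the class that summary() stamps onto its return value. A minimal standalone illustration of the same pattern (the class name summary.FakeModel is hypothetical; no Spark needed):

# Minimal illustration of S3 print dispatch on a summary class.
# "summary.FakeModel" is a hypothetical class name for demonstration only.
print.summary.FakeModel <- function(x, ...) {
  cat("Formula: ", x$formula, "\n")
  invisible(x)
}

s <- structure(list(formula = "y ~ x"), class = "summary.FakeModel")
print(s)  # dispatches to print.summary.FakeModel
# Formula:  y ~ x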

R/pkg/R/generics.R

Lines changed: 5 additions & 0 deletions
@@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml
 #' @export
 setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })

+#' @rdname spark.decisionTree
+#' @export
+setGeneric("spark.decisionTree",
+           function(data, formula, ...) { standardGeneric("spark.decisionTree") })
+
 #' @rdname spark.randomForest
 #' @export
 setGeneric("spark.randomForest",

R/pkg/R/mllib_classification.R

Lines changed: 15 additions & 23 deletions
@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))

-#' linear SVM Model
+#' Linear SVM Model
 #'
-#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package
+#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
+#' Currently only supports binary classification model with linear kernel.
 #' Users can print, make predictions on the produced model and save the model to the input path.
 #'
 #' @param data SparkDataFrame for training.
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param regParam The regularization parameter.
+#' @param regParam The regularization parameter. Only supports L2 regularization currently.
 #' @param maxIter Maximum iteration number.
 #' @param tol Convergence tolerance of iterations.
 #' @param standardization Whether to standardize the training features before fitting the model. The coefficients
@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
             new("LinearSVCModel", jobj = jobj)
           })

-# Predicted values based on an LinearSVCModel model
+# Predicted values based on a LinearSVCModel model

 #' @param newData a SparkDataFrame for testing.
-#' @return \code{predict} returns the predicted values based on an LinearSVCModel.
+#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
 #' @rdname spark.svmLinear
 #' @aliases predict,LinearSVCModel,SparkDataFrame-method
 #' @export
@@ -124,36 +125,27 @@ setMethod("predict", signature(object = "LinearSVCModel"),
             predict_internal(object, newData)
           })

-# Get the summary of an LinearSVCModel
+# Get the summary of a LinearSVCModel

-#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list includes \code{coefficients} (coefficients of the fitted model),
-#'         \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
-#'         \code{numFeatures} (number of features).
+#'         \code{numClasses} (number of classes), \code{numFeatures} (number of features).
 #' @rdname spark.svmLinear
 #' @aliases summary,LinearSVCModel-method
 #' @export
 #' @note summary(LinearSVCModel) since 2.2.0
 setMethod("summary", signature(object = "LinearSVCModel"),
           function(object) {
             jobj <- object@jobj
-            features <- callJMethod(jobj, "features")
-            labels <- callJMethod(jobj, "labels")
-            coefficients <- callJMethod(jobj, "coefficients")
-            nCol <- length(coefficients) / length(features)
-            coefficients <- matrix(unlist(coefficients), ncol = nCol)
-            intercept <- callJMethod(jobj, "intercept")
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
             numClasses <- callJMethod(jobj, "numClasses")
             numFeatures <- callJMethod(jobj, "numFeatures")
-            if (nCol == 1) {
-              colnames(coefficients) <- c("Estimate")
-            } else {
-              colnames(coefficients) <- unlist(labels)
-            }
-            rownames(coefficients) <- unlist(features)
-            list(coefficients = coefficients, intercept = intercept,
-                 numClasses = numClasses, numFeatures = numFeatures)
+            list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
           })

 # Save fitted LinearSVCModel to the input path
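
The net effect of this change is a simpler summary() return value: the coefficients come back from the wrapper's rFeatures/rCoefficients accessors as a one-column matrix with rows named by feature, and the separate intercept and per-label entries are gone. A hedged usage sketch (assumes a running SparkR session; exact estimates depend on the data and solver):

# Hedged sketch of the revised summary() output for a LinearSVCModel.
library(SparkR)
sparkR.session()

t <- as.data.frame(Titanic)
training <- createDataFrame(t)
model <- spark.svmLinear(training, Survived ~ ., regParam = 0.01)

s <- summary(model)
names(s)         # "coefficients" "numClasses" "numFeatures" -- no separate intercept entry
s$coefficients   # single "Estimate" column; row names come from the wrapper's rFeatures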

R/pkg/R/mllib_tree.R

Lines changed: 240 additions & 0 deletions
@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
 #' @note RandomForestClassificationModel since 2.1.0
 setClass("RandomForestClassificationModel", representation(jobj = "jobj"))

+#' S4 class that represents a DecisionTreeRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
+#' @export
+#' @note DecisionTreeRegressionModel since 2.3.0
+setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a DecisionTreeClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
+#' @export
+#' @note DecisionTreeClassificationModel since 2.3.0
+setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
+
 # Create the summary of a tree ensemble model (eg. Random Forest, GBT)
 summary.treeEnsemble <- function(model) {
   jobj <- model@jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
   invisible(x)
 }

+# Create the summary of a decision tree model
+summary.decisionTree <- function(model) {
+  jobj <- model@jobj
+  formula <- callJMethod(jobj, "formula")
+  numFeatures <- callJMethod(jobj, "numFeatures")
+  features <- callJMethod(jobj, "features")
+  featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
+  list(formula = formula,
+       numFeatures = numFeatures,
+       features = features,
+       featureImportances = featureImportances,
+       maxDepth = maxDepth,
+       jobj = jobj)
+}
+
+# Prints the summary of decision tree models
+print.summary.decisionTree <- function(x) {
+  jobj <- x$jobj
+  cat("Formula: ", x$formula)
+  cat("\nNumber of features: ", x$numFeatures)
+  cat("\nFeatures: ", unlist(x$features))
+  cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
+
+  summaryStr <- callJMethod(jobj, "summary")
+  cat("\n", summaryStr, "\n")
+  invisible(x)
+}
+
 #' Gradient Boosted Tree Model for Regression and Classification
 #'
 #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
           function(object, path, overwrite = FALSE) {
             write_internal(object, path, overwrite)
           })
+
+#' Decision Tree Model for Regression and Classification
+#'
+#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
+#' Decision Tree Regression} and
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
+#' Decision Tree Classification}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model, one of "regression" or "classification", to fit
+#' @param maxDepth Maximum depth of the tree (>= 0).
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#'                how to split on features at each node. More bins give higher granularity. Must be
+#'                >= 2 and >= number of categories in any categorical feature.
+#' @param impurity Criterion used for information gain calculation.
+#'                 For regression, must be "variance". For classification, must be one of
+#'                 "entropy" and "gini", default is "gini".
+#' @param seed integer seed for random number generation.
+#' @param minInstancesPerNode Minimum number of instances each child must have after a split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval Set checkpoint interval (>= 1) or disable checkpoint (-1).
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
+#'                     can speed up training of deeper trees. Users can set how often the cache
+#'                     should be checkpointed or disable it by setting checkpointInterval.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.decisionTree,SparkDataFrame,formula-method
+#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
+#' @rdname spark.decisionTree
+#' @name spark.decisionTree
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Decision Tree Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Decision Tree Classification Model
+#' t <- as.data.frame(Titanic)
+#' df <- createDataFrame(t)
+#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
+#' }
+#' @note spark.decisionTree since 2.3.0
+setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, type = c("regression", "classification"),
+                   maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
+                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+            type <- match.arg(type)
+            formula <- paste(deparse(formula), collapse = "")
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            switch(type,
+                   regression = {
+                     if (is.null(impurity)) impurity <- "variance"
+                     impurity <- match.arg(impurity, "variance")
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeRegressionModel", jobj = jobj)
+                   },
+                   classification = {
+                     if (is.null(impurity)) impurity <- "gini"
+                     impurity <- match.arg(impurity, c("gini", "entropy"))
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeClassificationModel", jobj = jobj)
+                   }
+            )
+          })
+
+# Get the summary of a Decision Tree Regression Model
+
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list of components includes \code{formula} (formula),
+#'         \code{numFeatures} (number of features), \code{features} (list of features),
+#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees).
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeRegressionModel-method
+#' @export
+#' @note summary(DecisionTreeRegressionModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeRegressionModel"
+            ans
+          })
+
+# Prints the summary of Decision Tree Regression Model
+
+#' @param x summary object of Decision Tree regression model or classification model
+#'          returned by \code{summary}.
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
+print.summary.DecisionTreeRegressionModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+# Get the summary of a Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeClassificationModel-method
+#' @export
+#' @note summary(DecisionTreeClassificationModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeClassificationModel"
+            ans
+          })
+
+# Prints the summary of Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
+print.summary.DecisionTreeClassificationModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+# Makes predictions from a Decision Tree Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction".
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeRegressionModel-method
+#' @export
+#' @note predict(DecisionTreeRegressionModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeClassificationModel-method
+#' @export
+#' @note predict(DecisionTreeClassificationModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save the Decision Tree Regression or Classification model to the input path.
+
+#' @param object A fitted Decision Tree regression model or classification model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,DecisionTreeRegressionModel,character-method
+#' @rdname spark.decisionTree
+#' @export
+#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
+setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' @aliases write.ml,DecisionTreeClassificationModel,character-method
+#' @rdname spark.decisionTree
+#' @export
+#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
+setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
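
Putting the new API together, a hedged end-to-end sketch (assumes a running SparkR session; the summary components match the list built by summary.decisionTree above, and the exact numbers depend on the data):

# Hedged end-to-end sketch of the new spark.decisionTree API.
library(SparkR)
sparkR.session()

df <- createDataFrame(longley)
model <- spark.decisionTree(df, Employed ~ ., type = "regression",
                            maxDepth = 4, maxBins = 16, seed = 42)

s <- summary(model)
s$formula             # "Employed ~ ."
s$numFeatures         # number of features after formula encoding
s$featureImportances  # string rendering of the importance vector
s$maxDepth            # 4

# Predictions land in a "prediction" column.
head(select(predict(model, df), "prediction"))

# Save/load round trip via write.ml/read.ml.
modelPath <- tempfile(pattern = "spark-decisionTree", fileext = ".tmp")
write.ml(model, modelPath)
restored <- read.ml(modelPath)
summary(restored)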
