
Commit b18b718

regression pass unit test
1 parent 463f965 commit b18b718

5 files changed: +207 -97 lines changed

R/pkg/NAMESPACE

Lines changed: 5 additions & 1 deletion
@@ -348,7 +348,9 @@ export("as.DataFrame",
        "uncacheTable",
        "print.summary.GeneralizedLinearRegressionModel",
        "read.ml",
-       "print.summary.KSTest")
+       "print.summary.KSTest",
+       "print.summary.DecisionTreeRegressionModel",
+       "print.summary.DecisionTreeClassificationModel")
 
 export("structField",
        "structField.jobj",
@@ -373,6 +375,8 @@ S3method(print, structField)
 S3method(print, structType)
 S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
+S3method(print, summary.DecisionTreeRegressionModel)
+S3method(print, summary.DecisionTreeClassificationModel)
 S3method(structField, character)
 S3method(structField, jobj)
 S3method(structType, jobj)
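
These NAMESPACE entries are what let base R's print generic dispatch on the summary objects returned by the new methods in mllib.R below. A minimal sketch of the intended call pattern, reusing the longley data and formula from the new unit test (the SparkR session setup is assumed, not part of this commit):

df <- suppressWarnings(createDataFrame(longley))
model <- spark.decisionTree(df, Employed ~ ., "regression")
s <- summary(model)   # a list tagged with class "summary.DecisionTreeRegressionModel"
print(s)              # dispatches to print.summary.DecisionTreeRegressionModel via the S3method above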

R/pkg/R/mllib.R

Lines changed: 105 additions & 34 deletions
@@ -1447,39 +1447,31 @@ print.summary.KSTest <- function(x, ...) {
   invisible(x)
 }
 
-#' Decision tree regression model.
+#' Decision Tree
 #'
-#' Fit Decision Tree regression model on a SparkDataFrame.
+#' @description
+#' \code{spark.decisionTree} tree
 #'
-#' @param data SparkDataFrame for training.
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#'                operators are supported, including '~', ':', '+', and '-'.
-#'                Note that operator '.' is not supported currently.
-#' @return a fitted decision tree regression model
-#' @rdname spark.decisionTreeRegressor
-#' @seealso rpart: \url{https://cran.r-project.org/web/packages/rpart/}
-#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(sqlContext, kyphosis)
-#' model <- spark.decisionTree(df, Kyphosis ~ Age + Number + Start)
-#' }
+#' Decision Tree
+#'
+#' @param data a SparkDataFrame of user data.
 #' @note spark.decisionTree since 2.1.0
 setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, type = c("regression", "classification")) {
+          function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32) {
             formula <- paste(deparse(formula), collapse = "")
             if (identical(type, "regression")) {
               jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", "fit",
-                                  data@sdf, formula)
+                                  data@sdf, formula, as.integer(maxDepth), as.integer(maxBins))
               new("DecisionTreeRegressionModel", jobj = jobj)
             } else if (identical(type, "classification")) {
-              jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassificationWrapper", "fit",
-                                  data@sdf, formula)
+              jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", "fit",
+                                  data@sdf, formula, as.integer(maxDepth), as.integer(maxBins))
               new("DecisionTreeClassificationModel", jobj = jobj)
             }
           })
 
-# Makes predictions from a Decision Tree model or a model produced by spark.decisionTree()
+# Makes predictions from a Decision Tree Regression model or
+# a model produced by spark.decisionTree()
 
 #' @param newData a SparkDataFrame for testing.
 #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
@@ -1492,6 +1484,20 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
             predict_internal(object, newData)
           })
 
+# Makes predictions from a Decision Tree Classification model or
+# a model produced by spark.decisionTree()
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction"
+#' @rdname spark.decisionTree
+#' @export
+#' @note predict(decisionTreeClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
 #' Save the Decision Tree Regression model to the input path.
 #'
 #' @param object A fitted Decision tree regression model
@@ -1504,23 +1510,88 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
 #' @export
 #' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
-         function(object, path, overwrite = FALSE) {
-           write_internal(object, path, overwrite)
-         })
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
 
-# Get the summary of an IsotonicRegressionModel model
+#' Save the Decision Tree Classification model to the input path.
+#'
+#' @param object A fitted Decision tree classification model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,DecisionTreeClassificationModel,character-method
+#' @rdname spark.decisionTreeClassification
+#' @export
+#' @note write.ml(DecisionTreeClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
 
-#' @param object a fitted IsotonicRegressionModel
-#' @param ... Other optional arguments to summary of an IsotonicRegressionModel
-#' @return \code{summary} returns the model's boundaries and prediction as lists
-#' @rdname spark.isoreg
-#' @aliases summary,IsotonicRegressionModel-method
+# Get the summary of a DecisionTreeRegressionModel
+
+#' @param object a fitted DecisionTreeRegressionModel
+#' @param ... Other optional arguments to summary of a DecisionTreeRegressionModel
+#' @return \code{summary} returns the model's features as lists, depth and number of nodes
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeRegressionModel-method
 #' @export
-#' @note summary(IsotonicRegressionModel) since 2.1.0
+#' @note summary(DecisionTreeRegressionModel) since 2.1.0
 setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
           function(object, ...) {
             jobj <- object@jobj
-            boundaries <- callJMethod(jobj, "boundaries")
-            predictions <- callJMethod(jobj, "predictions")
-            return(list(boundaries = boundaries, predictions = predictions))
-          })
+            features <- callJMethod(jobj, "features")
+            depth <- callJMethod(jobj, "depth")
+            numNodes <- callJMethod(jobj, "numNodes")
+            ans <- list(features = features, depth = depth, numNodes = numNodes)
+            class(ans) <- "summary.DecisionTreeRegressionModel"
+            ans
+          })
+
+# Get the summary of a DecisionTreeClassificationModel
+
+#' @param object a fitted DecisionTreeClassificationModel
+#' @param ... Other optional arguments to summary of a DecisionTreeClassificationModel
+#' @return \code{summary} returns the model's features as lists, depth and number of nodes
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeClassificationModel-method
+#' @export
+#' @note summary(DecisionTreeClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            depth <- callJMethod(jobj, "depth")
+            numNodes <- callJMethod(jobj, "numNodes")
+            ans <- list(features = features, depth = depth, numNodes = numNodes)
+            class(ans) <- "summary.DecisionTreeClassificationModel"
+            ans
+          })
+
+# Prints the summary of Decision Tree Regression Model
+
+#' @rdname spark.decisionTree
+#' @param x summary object of DecisionTreeRegressionModel returned by \code{summary}.
+#' @export
+#' @note print.summary.DecisionTreeRegressionModel since 2.1.0
+print.summary.DecisionTreeRegressionModel <- function(x, ...) {
+  jobj <- x@jobj
+  summaryStr <- callJMethod(jobj, "summary")
+  cat(summaryStr, "\n")
+  invisible(x)
+}
+
+# Prints the summary of Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @param x summary object of DecisionTreeClassificationModel returned by \code{summary}.
+#' @export
+#' @note print.summary.DecisionTreeClassificationModel since 2.1.0
+print.summary.DecisionTreeClassificationModel <- function(x, ...) {
+  jobj <- x@jobj
+  summaryStr <- callJMethod(jobj, "summary")
+  cat(summaryStr, "\n")
+  invisible(x)
+}
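
Taken together, the changes above give SparkR a spark.decisionTree front end with regression and classification variants, plus summary, predict, and write.ml/read.ml support. A short usage sketch of the API as defined above; the classification example with iris is illustrative only and is not part of this commit:

df <- suppressWarnings(createDataFrame(longley))

# regression, with the new maxDepth/maxBins arguments forwarded to the JVM wrapper
regModel <- spark.decisionTree(df, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
summary(regModel)            # list of features, depth, numNodes
head(predict(regModel, df))  # adds a "prediction" column

# classification takes the same arguments (iris here is a placeholder dataset)
irisDF <- suppressWarnings(createDataFrame(iris))
clsModel <- spark.decisionTree(irisDF, Species ~ Petal_Length + Petal_Width, "classification")

# both model types can be persisted and reloaded
path <- tempfile(pattern = "spark-decisionTree", fileext = ".tmp")
write.ml(regModel, path)
regModel2 <- read.ml(path)
unlink(path)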

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 30 additions & 0 deletions
@@ -791,4 +791,34 @@ test_that("spark.kstest", {
   expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
 })
 
+test_that("spark.decisionTree Regression", {
+  data <- suppressWarnings(createDataFrame(longley))
+  model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
+
+  # Test summary
+  stats <- summary(model)
+  expect_equal(stats$depth, 5)
+  expect_equal(stats$numNodes, 31)
+
+  # Test model predict
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$depth, stats2$depth)
+  expect_equal(stats$numNodes, stats2$numNodes)
+
+  unlink(modelPath)
+})
+
 sparkR.session.stop()
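
Only the regression path is exercised here, which matches the commit message; a classification counterpart would presumably mirror it. A hedged sketch of what such a test could look like, where the dataset, formula, and expected values are assumptions rather than anything from this change:

test_that("spark.decisionTree Classification", {
  # iris and the formula are placeholder choices, not from this commit
  df <- suppressWarnings(createDataFrame(iris))
  model <- spark.decisionTree(df, Species ~ Petal_Length + Petal_Width,
                              "classification", maxDepth = 5, maxBins = 32)

  stats <- summary(model)
  expect_true(stats$depth <= 5)          # depth is capped by maxDepth

  predictions <- collect(predict(model, df))
  expect_equal(nrow(predictions), 150)   # one prediction per input row

  # save/load round trip should preserve the tree structure
  modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
  write.ml(model, modelPath)
  model2 <- read.ml(modelPath)
  expect_equal(summary(model2)$numNodes, stats$numNodes)

  unlink(modelPath)
})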

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

Lines changed: 32 additions & 32 deletions
@@ -23,29 +23,23 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
-import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class DecisionTreeClassifierWrapper private (
     val pipeline: PipelineModel,
     val features: Array[String],
-    val labels: Array[String]) extends MLWritable {
-
-  import DecisionTreeClassifierWrapper.PREDICTED_LABEL_INDEX_COL
+    val maxDepth: Int,
+    val maxBins: Int) extends MLWritable {
 
   private val DTModel: DecisionTreeClassificationModel =
     pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]
 
-  lazy val maxDepth: Int = DTModel.getMaxDepth
-
-  lazy val maxBins: Int = DTModel.getMaxBins
-
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset)
-      .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(DTModel.getFeaturesCol)
   }
 
@@ -54,33 +48,36 @@ private[r] class DecisionTreeClassifierWrapper private (
 }
 
 private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] {
+  def fit(data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeClassifierWrapper = {
 
-  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
-  val PREDICTED_LABEL_COL = "prediction"
-
-  def fit(data: DataFrame, formula: String): DecisionTreeClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
-      .fit(data)
-    // get labels and feature names from output schema
-    val schema = rFormula.transform(data).schema
-    val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+      .setFeaturesCol("features")
+      .setLabelCol("label")
+
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
       .attributes.get
     val features = featureAttrs.map(_.name.get)
+
     // assemble and fit the pipeline
-    val decisionTree = new DecisionTreeClassifier()
-      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
-    val idxToStr = new IndexToString()
-      .setInputCol(PREDICTED_LABEL_INDEX_COL)
-      .setOutputCol(PREDICTED_LABEL_COL)
-      .setLabels(labels)
+    val decisionTreeClassification = new DecisionTreeClassifier()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+
     val pipeline = new Pipeline()
-      .setStages(Array(rFormula, decisionTree, idxToStr))
+      .setStages(Array(rFormulaModel, decisionTreeClassification))
       .fit(data)
-    new DecisionTreeClassifierWrapper(pipeline, features, labels)
+
+    new DecisionTreeClassifierWrapper(pipeline, features, maxDepth, maxBins)
   }
 
   override def read: MLReader[DecisionTreeClassifierWrapper] =
@@ -97,7 +94,8 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
 
       val rMetadata = ("class" -> instance.getClass.getName) ~
        ("features" -> instance.features.toSeq) ~
-        ("labels" -> instance.labels.toSeq)
+        ("maxDepth" -> instance.maxDepth) ~
+        ("maxBins" -> instance.maxBins)
      val rMetadataJson: String = compact(render(rMetadata))
 
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
@@ -116,8 +114,10 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val features = (rMetadata \ "features").extract[Array[String]]
-      val labels = (rMetadata \ "labels").extract[Array[String]]
-      new DecisionTreeClassifierWrapper(pipeline, features, labels)
+      val maxDepth = (rMetadata \ "maxDepth").extract[Int]
+      val maxBins = (rMetadata \ "maxBins").extract[Int]
+
+      new DecisionTreeClassifierWrapper(pipeline, features, maxDepth, maxBins)
    }
  }
 }