Skip to content

Commit d034735

Browse files
committed
classification unit test
1 parent d107ab9 commit d034735

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

R/pkg/R/mllib.R

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,7 @@ print.summary.KSTest <- function(x, ...) {
14741474
#' df <- createDataFrame(longley)
14751475
#'
14761476
#' # fit a Decision Tree Regression Model
1477-
#' model <- spark.decisionTree(data, Employed~., "regression", maxDepth=5, maxBins=16)
1477+
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
14781478
#'
14791479
#' # get the summary of the model
14801480
#' summary(model)
@@ -1579,7 +1579,7 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
15791579
features <- callJMethod(jobj, "features")
15801580
depth <- callJMethod(jobj, "depth")
15811581
numNodes <- callJMethod(jobj, "numNodes")
1582-
ans <- list(features = features, depth = depth, numNodes = numNodes)
1582+
ans <- list(features = features, depth = depth, numNodes = numNodes, jobj = jobj)
15831583
class(ans) <- "summary.DecisionTreeRegressionModel"
15841584
ans
15851585
})
@@ -1594,15 +1594,17 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
15941594
#' @export
15951595
#' @note summary(DecisionTreeClassificationModel) since 2.1.0
15961596
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
1597-
function(object, ...) {
1598-
jobj <- object@jobj
1599-
features <- callJMethod(jobj, "features")
1600-
depth <- callJMethod(jobj, "depth")
1601-
numNodes <- callJMethod(jobj, "numNodes")
1602-
ans <- list(features = features, depth = depth, numNodes = numNodes)
1603-
class(ans) <- "summary.DecisionTreeClassificationModel"
1604-
ans
1605-
})
1597+
function(object, ...) {
1598+
jobj <- object@jobj
1599+
features <- callJMethod(jobj, "features")
1600+
depth <- callJMethod(jobj, "depth")
1601+
numNodes <- callJMethod(jobj, "numNodes")
1602+
numClasses <- callJMethod(jobj, "numClasses")
1603+
ans <- list(features = features, depth = depth,
1604+
numNodes = numNodes, numClasses = numClasses, jobj = jobj)
1605+
class(ans) <- "summary.DecisionTreeClassificationModel"
1606+
ans
1607+
})
16061608

16071609
# Prints the summary of Decision Tree Regression Model
16081610

@@ -1611,11 +1613,11 @@ function(object, ...) {
16111613
#' @export
16121614
#' @note print.summary.DecisionTreeRegressionModel since 2.1.0
16131615
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
1614-
jobj <- x@jobj
1615-
summaryStr <- callJMethod(jobj, "summary")
1616-
cat(summaryStr, "\n")
1617-
invisible(x)
1618-
}
1616+
jobj <- x$jobj
1617+
summaryStr <- callJMethod(jobj, "summary")
1618+
cat(summaryStr, "\n")
1619+
invisible(x)
1620+
}
16191621

16201622
# Prints the summary of Decision Tree Classification Model
16211623

@@ -1624,7 +1626,7 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) {
16241626
#' @export
16251627
#' @note print.summary.DecisionTreeClassificationModel since 2.1.0
16261628
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
1627-
jobj <- x@jobj
1629+
jobj <- x$jobj
16281630
summaryStr <- callJMethod(jobj, "summary")
16291631
cat(summaryStr, "\n")
16301632
invisible(x)

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,4 +821,29 @@ test_that("spark.decisionTree Regression", {
821821
unlink(modelPath)
822822
})
823823

824+
test_that("spark.decisionTree Classification", {
825+
data <- suppressWarnings(createDataFrame(iris))
826+
model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
827+
maxDepth = 5, maxBins = 16)
828+
829+
# Test summary
830+
stats <- summary(model)
831+
expect_equal(stats$depth, 5)
832+
expect_equal(stats$numNodes, 19)
833+
expect_equal(stats$numClasses, 3)
834+
835+
# Test model save/load
836+
modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
837+
write.ml(model, modelPath)
838+
expect_error(write.ml(model, modelPath))
839+
write.ml(model, modelPath, overwrite = TRUE)
840+
model2 <- read.ml(modelPath)
841+
stats2 <- summary(model2)
842+
expect_equal(stats$depth, stats2$depth)
843+
expect_equal(stats$numNodes, stats2$numNodes)
844+
expect_equal(stats$numClasses, stats2$numClasses)
845+
846+
unlink(modelPath)
847+
})
848+
824849
sparkR.session.stop()

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ private[r] class DecisionTreeClassifierWrapper private (
3838
private val DTModel: DecisionTreeClassificationModel =
3939
pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]
4040

41+
lazy val depth: Int = DTModel.depth
42+
lazy val numNodes: Int = DTModel.numNodes
43+
lazy val numClasses: Int = DTModel.numClasses
44+
45+
def summary: String = DTModel.toDebugString
46+
4147
def transform(dataset: Dataset[_]): DataFrame = {
4248
pipeline.transform(dataset)
4349
.drop(DTModel.getFeaturesCol)

0 commit comments

Comments
 (0)