
Commit bee4868

DecisionTree Wrapper in SparkR

1 parent 7aeb20b · commit bee4868

File tree

R/pkg/NAMESPACE
R/pkg/R/generics.R
R/pkg/R/mllib.R
mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala
mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala
mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala

6 files changed: +322 −1 lines changed

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion

@@ -43,7 +43,8 @@ exportMethods("glm",
               "spark.isoreg",
               "spark.gaussianMixture",
               "spark.als",
-              "spark.kstest")
+              "spark.kstest",
+              "spark.decisionTree")

 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 5 additions & 0 deletions

@@ -1358,6 +1358,11 @@ setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
 #' @export
 setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })

+#' @rdname spark.decisionTree
+#' @export
+setGeneric("spark.decisionTree",
+           function(data, formula, ...) { standardGeneric("spark.decisionTree") })
+
 #' @rdname spark.gaussianMixture
 #' @export
 setGeneric("spark.gaussianMixture",

R/pkg/R/mllib.R

Lines changed: 68 additions & 0 deletions

@@ -95,6 +95,20 @@ setClass("ALSModel", representation(jobj = "jobj"))
 #' @note KSTest since 2.1.0
 setClass("KSTest", representation(jobj = "jobj"))

+#' S4 class that represents a DecisionTreeRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
+#' @export
+#' @note DecisionTreeRegressionModel since 2.1.0
+setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a DecisionTreeClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
+#' @export
+#' @note DecisionTreeClassificationModel since 2.1.0
+setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
+
 #' Saves the MLlib model to the input path
 #'
 #' Saves the MLlib model to the input path. For more information, see the specific

@@ -897,6 +911,22 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
   write_internal(object, path, overwrite)
 })

+#' Save the decision tree regression model to the input path.
+#'
+#' @param object A fitted decision tree regression model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @aliases write.ml,DecisionTreeRegressionModel,character-method
+#' @rdname spark.decisionTree
+#' @export
+#' @note write.ml(DecisionTreeRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path path of the model to read.

@@ -932,6 +962,8 @@ read.ml <- function(path) {
     new("GaussianMixtureModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
     new("ALSModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
+    new("DecisionTreeRegressionModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }

@@ -1427,3 +1459,39 @@ print.summary.KSTest <- function(x, ...) {
   cat(summaryStr, "\n")
   invisible(x)
 }
+
+#' Decision tree model.
+#'
+#' Fit a decision tree regression or classification model on a SparkDataFrame.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#'                Note that operator '.' is not supported currently.
+#' @param type Type of model to fit, either "regression" or "classification".
+#' @return A fitted decision tree model.
+#' @rdname spark.decisionTree
+#' @seealso rpart: \url{https://cran.r-project.org/web/packages/rpart/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(kyphosis)
+#' model <- spark.decisionTree(df, Kyphosis ~ Age + Number + Start, type = "classification")
+#' }
+setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, type = c("regression", "classification")) {
+            # match.arg() resolves the default vector to a single validated choice
+            type <- match.arg(type)
+            formula <- paste(deparse(formula), collapse = "")
+            if (type == "regression") {
+              jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", "fit",
+                                  data@sdf, formula)
+              new("DecisionTreeRegressionModel", jobj = jobj)
+            } else {
+              jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", "fit",
+                                  data@sdf, formula)
+              new("DecisionTreeClassificationModel", jobj = jobj)
+            }
+          })
+
+#' @rdname spark.decisionTree
+#' @export
+setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
+          function(object, newData) {
+            dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+          })
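Taken together, the R-side changes give a single entry point for fitting and scoring. A minimal usage sketch, assuming an active SparkR session; longley is a stock numeric R dataset chosen here for illustration and is not part of this diff:

# Fit a regression tree on a numeric label (type defaults to "regression")
df <- createDataFrame(longley)
model <- spark.decisionTree(df, Employed ~ GNP + Population)

# Score the training data; predictions come back in a "prediction" column
predictions <- predict(model, df)
head(predictions)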
mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class DecisionTreeClassifierWrapper private (
    val pipeline: PipelineModel,
    val features: Array[String],
    val labels: Array[String]) extends MLWritable {

  import DecisionTreeClassifierWrapper.PREDICTED_LABEL_INDEX_COL

  private val DTModel: DecisionTreeClassificationModel =
    pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]

  lazy val maxDepth: Int = DTModel.getMaxDepth

  lazy val maxBins: Int = DTModel.getMaxBins

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(PREDICTED_LABEL_INDEX_COL)
      .drop(DTModel.getFeaturesCol)
  }

  override def write: MLWriter =
    new DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this)
}

private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] {

  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
  val PREDICTED_LABEL_COL = "prediction"

  def fit(data: DataFrame, formula: String): DecisionTreeClassifierWrapper = {
    val rFormula = new RFormula()
      .setFormula(formula)
      .fit(data)
    // get labels and feature names from output schema
    val schema = rFormula.transform(data).schema
    val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
      .asInstanceOf[NominalAttribute]
    val labels = labelAttr.values.get
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    // assemble and fit the pipeline
    val decisionTree = new DecisionTreeClassifier()
      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
    val idxToStr = new IndexToString()
      .setInputCol(PREDICTED_LABEL_INDEX_COL)
      .setOutputCol(PREDICTED_LABEL_COL)
      .setLabels(labels)
    val pipeline = new Pipeline()
      .setStages(Array(rFormula, decisionTree, idxToStr))
      .fit(data)
    new DecisionTreeClassifierWrapper(pipeline, features, labels)
  }

  override def read: MLReader[DecisionTreeClassifierWrapper] =
    new DecisionTreeClassifierWrapperReader

  override def load(path: String): DecisionTreeClassifierWrapper = super.load(path)

  class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("features" -> instance.features.toSeq) ~
        ("labels" -> instance.labels.toSeq)
      val rMetadataJson: String = compact(render(rMetadata))

      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
      instance.pipeline.save(pipelinePath)
    }
  }

  class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {

    override def load(path: String): DecisionTreeClassifierWrapper = {
      implicit val format = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString
      val pipeline = PipelineModel.load(pipelinePath)

      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val features = (rMetadata \ "features").extract[Array[String]]
      val labels = (rMetadata \ "labels").extract[Array[String]]
      new DecisionTreeClassifierWrapper(pipeline, features, labels)
    }
  }
}
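The classifier wrapper is reachable from R through the same generic. A sketch, assuming an active SparkR session; note that this commit registers predict and write.ml only for the regression model, so a classification model can be fitted but not yet scored or saved from R:

# iris is a stock R dataset; createDataFrame replaces "." in column names with "_"
df <- createDataFrame(iris)
model <- spark.decisionTree(df, Species ~ Sepal_Length + Sepal_Width,
                            type = "classification")
class(model)  # "DecisionTreeClassificationModel"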
mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class DecisionTreeRegressorWrapper private (
    val pipeline: PipelineModel,
    val features: Array[String]) extends MLWritable {

  private val DTModel: DecisionTreeRegressionModel =
    pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel]

  lazy val maxDepth: Int = DTModel.getMaxDepth

  lazy val maxBins: Int = DTModel.getMaxBins

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(DTModel.getFeaturesCol)
  }

  override def write: MLWriter =
    new DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this)
}

private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] {

  def fit(data: DataFrame, formula: String): DecisionTreeRegressorWrapper = {
    val rFormula = new RFormula()
      .setFormula(formula)
      .fit(data)
    // get feature names from output schema; the label is continuous for
    // regression, so no nominal label attributes are extracted here
    val schema = rFormula.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    // assemble and fit the pipeline; the regressor writes its output to the
    // default "prediction" column, so no index-to-string stage is needed
    val decisionTree = new DecisionTreeRegressor()
    val pipeline = new Pipeline()
      .setStages(Array(rFormula, decisionTree))
      .fit(data)
    new DecisionTreeRegressorWrapper(pipeline, features)
  }

  override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader

  override def load(path: String): DecisionTreeRegressorWrapper = super.load(path)

  class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("features" -> instance.features.toSeq)
      val rMetadataJson: String = compact(render(rMetadata))

      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
      instance.pipeline.save(pipelinePath)
    }
  }

  class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] {

    override def load(path: String): DecisionTreeRegressorWrapper = {
      implicit val format = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString
      val pipeline = PipelineModel.load(pipelinePath)

      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val features = (rMetadata \ "features").extract[Array[String]]
      new DecisionTreeRegressorWrapper(pipeline, features)
    }
  }
}
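The wrapper exposes maxDepth and maxBins, but this commit adds no summary() method on the R side. Until one exists, the values can be reached through SparkR's internal JVM bridge; a sketch, for illustration only, since callJMethod is not public API:

# Scala vals compile to accessor methods, so they can be invoked by name
callJMethod(model@jobj, "maxDepth")
callJMethod(model@jobj, "maxBins")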

mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala

Lines changed: 2 additions & 0 deletions

@@ -54,6 +54,8 @@ private[r] object RWrappers extends MLReader[Object] {
       GaussianMixtureWrapper.load(path)
     case "org.apache.spark.ml.r.ALSWrapper" =>
       ALSWrapper.load(path)
+    case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
+      DecisionTreeRegressorWrapper.load(path)
     case _ =>
       throw new SparkException(s"SparkR read.ml does not support load $className")
   }
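With this dispatch entry in place, a fitted regression tree survives a save/load round trip from R. Continuing the earlier sketch, with an illustrative temporary path:

# Save the model, then restore it; read.ml keys off the stored wrapper class name
modelPath <- tempfile(pattern = "spark-decisionTree")
write.ml(model, modelPath)
model2 <- read.ml(modelPath)
predictions <- predict(model2, df)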
