Skip to content

Commit f825352

Browse files
committed
Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel.
1 parent fc47bb6 commit f825352

File tree

5 files changed

+455
-0
lines changed

5 files changed

+455
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Decision tree classification and regression using MLlib.
20+
"""
21+
22+
import sys
23+
24+
from operator import add
25+
26+
from pyspark import SparkContext
27+
from pyspark.mllib.regression import LabeledPoint
28+
from pyspark.mllib.tree import DecisionTree
29+
30+
31+
# Convert one comma-separated line of text into an MLlib LabeledPoint.
# The first field is the label and the remaining fields are the features.
def parsePoint(line):
    fields = line.split(',')
    label = float(fields[0])
    features = [float(tok) for tok in fields[1:]]
    if label == -1:
        # MLlib expects the negative class to be labeled 0 rather than -1.
        label = 0
    return LabeledPoint(label, features)
37+
38+
# Return the accuracy of a DecisionTreeModel on the given RDD[LabeledPoint]:
# the number of points whose prediction equals their label, divided by the
# total number of points.
def getAccuracy(dtModel, data):
    # Each zipped element is a (prediction, label) pair; a match adds 1 to the count.
    seqOp = (lambda acc, x: acc + (x[0] == x[1]))
    # Python lambdas use ':' (not the Scala-style '=>') to separate args from body.
    trainCorrect = \
        dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add)
    # Adding 0.0 makes the denominator a float so the division is not truncated.
    return trainCorrect / (0.0 + data.count())
44+
45+
46+
if __name__ == "__main__":
    # This example takes no command-line arguments.
    if len(sys.argv) != 1:
        sys.stderr.write("Usage: decision_tree\n")
        sys.exit(-1)
    sc = SparkContext(appName="PythonDT")

    # Load data.
    dataPath = 'data/mllib/sample_tree_data.csv'
    points = sc.textFile(dataPath).map(parsePoint)

    # Train a classifier.
    model = DecisionTree.trainClassifier(points, numClasses=2)
    # Print learned tree.
    # The numeric results are wrapped in str() because concatenating a str with
    # a number raises TypeError. Single-argument print(...) behaves the same
    # under Python 2 and Python 3.
    print("Model numNodes: " + str(model.numNodes()) + "\n")
    print("Model depth: " + str(model.depth()) + "\n")
    print(model)
    # Check accuracy.
    print("Training accuracy: " + str(getAccuracy(model, points)) + "\n")
64+
65+
# Switch labels and first feature to create a regression dataset with categorical features.
66+
"""
67+
datasetInfo = DatasetInfo(numClasses=0, numFeatures=numFeatures)
68+
dtParams = DecisionTreeRegressor.defaultParams()
69+
model = DecisionTreeRegressor.train(points, datasetInfo, dtParams)
70+
# Print learned tree.
71+
print "Model numNodes: " + model.numNodes() + "\n"
72+
print "Model depth: " + model.depth() + "\n"
73+
print model
74+
# Check error.
75+
print "Training accuracy: " + getAccuracy(model, points) + "\n"
76+
"""

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,23 @@ package org.apache.spark.mllib.api.python
1919

2020
import java.nio.{ByteBuffer, ByteOrder}
2121

22+
import scala.collection.JavaConversions._
23+
2224
import org.apache.spark.annotation.DeveloperApi
2325
import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
2426
import org.apache.spark.mllib.classification._
2527
import org.apache.spark.mllib.clustering._
2628
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
2729
import org.apache.spark.mllib.recommendation._
2830
import org.apache.spark.mllib.regression._
31+
import org.apache.spark.mllib.tree.configuration.Algo._
32+
import org.apache.spark.mllib.tree.configuration.Strategy
33+
import org.apache.spark.mllib.tree.DecisionTree
34+
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
35+
import org.apache.spark.mllib.tree.model.DecisionTreeModel
2936
import org.apache.spark.mllib.util.MLUtils
3037
import org.apache.spark.rdd.RDD
38+
import org.apache.spark.util.Utils
3139

3240
/**
3341
* :: DeveloperApi ::
@@ -453,4 +461,75 @@ class PythonMLLibAPI extends Serializable {
453461
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
454462
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
455463
}
464+
465+
/**
 * Java stub for Python mllib DecisionTree.train().
 * This stub returns a handle to the Java object instead of the content of the Java object.
 * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
 * see the Py4J documentation.
 * @param dataBytesJRDD Training data; each element is decoded with deserializeLabeledPoint
 * @param algoStr Either "classification" or "regression"
 * @param numClasses Forwarded to the Strategy as numClassesForClassification
 * @param categoricalFeaturesInfoJMap Categorical features info, as Java map; converted to a
 *                                    Scala Map and forwarded to the Strategy
 *                                    (presumably feature index -> number of categories;
 *                                    TODO confirm against Strategy)
 * @param impurityStr One of "gini", "entropy" or "variance"
 * @param maxDepth Forwarded to the Strategy as maxDepth
 * @param maxBins Forwarded to the Strategy as maxBins
 * @return The model produced by DecisionTree.train
 * @throws IllegalArgumentException if algoStr or impurityStr is not one of the names above
 */
def trainDecisionTreeModel(
    dataBytesJRDD: JavaRDD[Array[Byte]],
    algoStr: String,
    numClasses: Int,
    categoricalFeaturesInfoJMap: java.util.Map[Int,Int],
    impurityStr: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {

  val labeledPoints = dataBytesJRDD.rdd.map(deserializeLabeledPoint)

  // Lookup tables from the option names accepted from Python to the MLlib values.
  val algosByName = Map[String, Algo](
    "classification" -> Classification,
    "regression" -> Regression)
  val impuritiesByName = Map[String, Impurity](
    "gini" -> Gini,
    "entropy" -> Entropy,
    "variance" -> Variance)

  // The algorithm name is validated before the impurity name.
  val chosenAlgo: Algo = algosByName.getOrElse(
    algoStr, throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr"))
  val chosenImpurity: Impurity = impuritiesByName.getOrElse(
    impurityStr, throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr"))

  val strategy = new Strategy(
    algo = chosenAlgo,
    impurity = chosenImpurity,
    maxDepth = maxDepth,
    numClassesForClassification = numClasses,
    maxBins = maxBins,
    // .toMap is available on the Java map through the implicit JavaConversions import.
    categoricalFeaturesInfo = categoricalFeaturesInfoJMap.toMap)

  DecisionTree.train(labeledPoints, strategy)
}
506+
507+
/**
 * Predict the label of a single data point.
 * This is a Java stub for python DecisionTreeModel.predict()
 *
 * @param model Model used to make the prediction
 * @param featuresBytes Serialized feature vector for data point; decoded with
 *                      deserializeDoubleVector
 * @return predicted label
 */
def predictDecisionTreeModel(
    model: DecisionTreeModel,
    featuresBytes: Array[Byte]): Double = {
  model.predict(deserializeDoubleVector(featuresBytes))
}
520+
521+
/**
 * Predict the labels of a collection of data points.
 * This is a Java stub for python DecisionTreeModel.predict()
 *
 * @param model Model used to make the predictions
 * @param dataJRDD A JavaRDD with serialized feature vectors; each element is decoded with
 *                 deserializeDoubleVector
 * @return JavaRDD of serialized predictions
 */
def predictDecisionTreeModel(
    model: DecisionTreeModel,
    dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
  val featureVectors = dataJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
  val predictions = model.predict(featureVectors)
  // NOTE(review): the byte format produced by Utils.serialize is not visible here;
  // confirm that the Python side decodes predictions in the same format.
  predictions.map(prediction => Utils.serialize(prediction))
}
534+
456535
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
5050
def predict(features: RDD[Vector]): RDD[Double] = {
5151
features.map(x => predict(x))
5252
}
53+
54+
/**
 * Total number of nodes in the tree, leaf nodes included.
 * Computed by counting recursively from the top node.
 */
def numNodes: Int = topNode.numNodesRecursive
60+
61+
/**
 * Depth of the tree, measured from the top node.
 * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
 */
def depth: Int = topNode.depthRecursive
68+
69+
/**
 * Render the full model as a multi-line string: a header line naming the kind of model,
 * followed by the tree rendered from the top node with an initial indentation of 2.
 *
 * @throws IllegalArgumentException if algo is neither Classification nor Regression
 */
override def toString: String = {
  // The model kind is resolved first, so an unknown algo fails before the tree is rendered.
  val modelKind = algo match {
    case Classification => "classifier"
    case Regression => "regressor"
    case _ => throw new IllegalArgumentException(
      s"DecisionTreeModel given unknown algo parameter: $algo.")
  }
  s"DecisionTreeModel $modelKind\n" + topNode.toStringRecursive(2)
}
80+
5381
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,59 @@ class Node (
9191
}
9292
}
9393
}
94+
95+
/**
 * Count the nodes in the subtree rooted at this node, including this node and all leaves.
 * A leaf counts as 1; an internal node counts as 1 plus the counts of both children.
 */
def numNodesRecursive: Int = {
  if (!isLeaf) {
    val leftCount = leftNode.get.numNodesRecursive
    val rightCount = rightNode.get.numNodesRecursive
    1 + leftCount + rightCount
  } else {
    1
  }
}
105+
106+
/**
 * Depth of the subtree rooted at this node.
 * E.g.: Depth 0 means this is a leaf node; an internal node is one level deeper than
 * its deeper child.
 */
def depthRecursive: Int = {
  if (!isLeaf) {
    val leftDepth = leftNode.get.depthRecursive
    val rightDepth = rightNode.get.depthRecursive
    1 + math.max(leftDepth, rightDepth)
  } else {
    0
  }
}
117+
118+
/**
 * Recursively render the subtree rooted at this node as text.
 * A leaf produces a single "Predict" line; an internal node produces an "If" line, the left
 * subtree, an "Else" line, and the right subtree.
 * @param indentFactor The number of spaces placed before each line emitted for this node.
 *                     Children are rendered with one more space than their parent.
 */
def toStringRecursive(indentFactor: Int = 0): String = {

  // Describe the condition under which a point follows the left or the right branch.
  def describeBranch(nodeSplit: Split, takeLeft: Boolean) : String = {
    nodeSplit.featureType match {
      case Continuous =>
        val comparison = if (takeLeft) "<=" else ">"
        s"(feature ${nodeSplit.feature} $comparison ${nodeSplit.threshold})"
      case Categorical =>
        val membership = if (takeLeft) "in" else "not in"
        s"(feature ${nodeSplit.feature} $membership ${nodeSplit.categories})"
    }
  }

  val indent: String = " " * indentFactor
  if (!isLeaf) {
    val nodeSplit = split.get
    val ifLine = indent + s"If ${describeBranch(nodeSplit, takeLeft = true)}\n"
    val leftText = leftNode.get.toStringRecursive(indentFactor + 1)
    val elseLine = indent + s"Else ${describeBranch(nodeSplit, takeLeft = false)}\n"
    val rightText = rightNode.get.toStringRecursive(indentFactor + 1)
    ifLine + leftText + elseLine + rightText
  } else {
    indent + s"Predict: $predict\n"
  }
}
148+
94149
}

0 commit comments

Comments
 (0)