
Commit 8d6ac2b

jkbradley authored and mengxr committed
[SPARK-2478] [mllib] DecisionTree Python API
Added experimental Python API for Decision Trees.

API:
* class DecisionTreeModel
** predict() for single examples and RDDs, taking both feature vectors and LabeledPoints
** numNodes()
** depth()
** __str__()
* class DecisionTree
** trainClassifier()
** trainRegressor()
** train()

Examples and testing:
* Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py
* Have also tested example usage in doc of python/pyspark/mllib/tree.py which tests single-example prediction with dense and sparse vectors

Also: Small bug fix in python/pyspark/mllib/_common.py: In _linear_predictor_typecheck, changed check for RDD to use isinstance() instead of type() in order to catch RDD subclasses.

CC mengxr manishamde

Author: Joseph K. Bradley <[email protected]>

Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits:

3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py Small updates based on github review.
6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
affceb9 [Joseph K. Bradley]
* Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.)
* Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature.
67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel
* predict(empty RDD) returns an empty RDD instead of an error.
* Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint]
* predict() does not cache serialized RDD any more.
aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
bf21be4 [Joseph K. Bradley] removed old run() func from DecisionTree
fa10ea7 [Joseph K. Bradley] Small style update
7968692 [Joseph K. Bradley] small braces typo fix
e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite
db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new
6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
93953f1 [Joseph K. Bradley] Likely done with Python API.
6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API
188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more.
2b20c61 [Joseph K. Bradley] Small doc and style updates
1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals
8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type.
376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart.
* In code, replaced usages of maxDepth <-- maxDepth + 1
* In params, replace settings of maxDepth <-- maxDepth - 1
e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
59750f8 [Joseph K. Bradley]
* Updated Strategy to check numClassesForClassification only if algo=Classification.
* Updates based on comments:
** DecisionTreeRunner
*** Made dataFormat arg default to libsvm
** Small cleanups
** tree.Node: Made recursive helper methods private, and renamed them.
52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now.
8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification
cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features.
8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
2283df8 [Joseph K. Bradley] 2 bug fixes.
73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit.
f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel.

(cherry picked from commit 3f67382)
Signed-off-by: Xiangrui Meng <[email protected]>
1 parent e221108 commit 8d6ac2b
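
The bin-indexing bug described under 225822f comes down to two different upper bounds. As a rough illustration only (the helper names below are hypothetical and not part of MLlib), this sketch compares the bound quoted in the commit message, 2^(k-1) - 1, with the feature arity k that the message says ordered categorical features should use:

```python
def unordered_bin_upper_bound(arity):
    """Bound quoted in the commit message: 2^(arity - 1) - 1."""
    return 2 ** (arity - 1) - 1


def ordered_bin_upper_bound(arity):
    """Bound the commit message says ordered features should use: the arity."""
    return arity


# Print arity, unordered bound, ordered bound side by side.
for k in (2, 3, 5, 10):
    print(k, unordered_bin_upper_bound(k), ordered_bin_upper_bound(k))
```

The two bounds diverge quickly as the arity grows, which is why using the unordered bound for an ordered feature indexes the wrong range of bins.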

10 files changed: +509 -21 lines changed
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Decision tree classification and regression using MLlib.
+"""
+
+import numpy, os, sys
+
+from operator import add
+
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.util import MLUtils
+
+
+def getAccuracy(dtModel, data):
+    """
+    Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint].
+    """
+    seqOp = (lambda acc, x: acc + (x[0] == x[1]))
+    predictions = dtModel.predict(data.map(lambda x: x.features))
+    truth = data.map(lambda p: p.label)
+    trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add)
+    if data.count() == 0:
+        return 0
+    return trainCorrect / (0.0 + data.count())
+
+
+def getMSE(dtModel, data):
+    """
+    Return mean squared error (MSE) of DecisionTreeModel on the given
+    RDD[LabeledPoint].
+    """
+    seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1]))
+    predictions = dtModel.predict(data.map(lambda x: x.features))
+    truth = data.map(lambda p: p.label)
+    trainMSE = predictions.zip(truth).aggregate(0, seqOp, add)
+    if data.count() == 0:
+        return 0
+    return trainMSE / (0.0 + data.count())
+
+
+def reindexClassLabels(data):
+    """
+    Re-index class labels in a dataset to the range {0,...,numClasses-1}.
+    If all labels in that range already appear at least once,
+    then the returned RDD is the same one (without a mapping).
+    Note: If a label simply does not appear in the data,
+          the index will not include it.
+          Be aware of this when reindexing subsampled data.
+    :param data: RDD of LabeledPoint where labels are integer values
+                 denoting labels for a classification problem.
+    :return: Pair (reindexedData, origToNewLabels) where
+             reindexedData is an RDD of LabeledPoint with labels in
+             the range {0,...,numClasses-1}, and
+             origToNewLabels is a dictionary mapping original labels
+             to new labels.
+    """
+    # classCounts: class --> # examples in class
+    classCounts = data.map(lambda x: x.label).countByValue()
+    numExamples = sum(classCounts.values())
+    sortedClasses = sorted(classCounts.keys())
+    numClasses = len(classCounts)
+    # origToNewLabels: class --> index in 0,...,numClasses-1
+    if (numClasses < 2):
+        print >> sys.stderr, \
+            "Dataset for classification should have at least 2 classes." + \
+            " The given dataset had only %d classes." % numClasses
+        exit(1)
+    origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])
+
+    print "numClasses = %d" % numClasses
+    print "Per-class example fractions, counts:"
+    print "Class\tFrac\tCount"
+    for c in sortedClasses:
+        frac = classCounts[c] / (numExamples + 0.0)
+        print "%g\t%g\t%d" % (c, frac, classCounts[c])
+
+    if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
+        return (data, origToNewLabels)
+    else:
+        reindexedData = \
+            data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features))
+        return (reindexedData, origToNewLabels)
+
+
+def usage():
+    print >> sys.stderr, \
+        "Usage: decision_tree_runner [libsvm format data filepath]\n" + \
+        " Note: This only supports binary classification."
+    exit(1)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 2:
+        usage()
+    sc = SparkContext(appName="PythonDT")
+
+    # Load data.
+    dataPath = 'data/mllib/sample_libsvm_data.txt'
+    if len(sys.argv) == 2:
+        dataPath = sys.argv[1]
+    if not os.path.isfile(dataPath):
+        usage()
+    points = MLUtils.loadLibSVMFile(sc, dataPath)
+
+    # Re-index class labels if needed.
+    (reindexedData, origToNewLabels) = reindexClassLabels(points)
+
+    # Train a classifier.
+    model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
+    # Print learned tree and stats.
+    print "Trained DecisionTree for classification:"
+    print " Model numNodes: %d\n" % model.numNodes()
+    print " Model depth: %d\n" % model.depth()
+    print " Training accuracy: %g\n" % getAccuracy(model, reindexedData)
+    print model
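
The re-indexing done by reindexClassLabels in the file above can be tried without a Spark cluster. The following is a minimal sketch, assuming a plain Python list stands in for the label column of the RDD; `reindex_labels` is a hypothetical helper, not part of the commit, and it raises ValueError where the runner prints to stderr and exits:

```python
def reindex_labels(labels):
    """Map arbitrary class labels onto 0..numClasses-1 in sorted order.

    `labels` is a plain list standing in for the RDD's label column.
    Returns (reindexed_labels, orig_to_new).
    """
    sorted_classes = sorted(set(labels))
    if len(sorted_classes) < 2:
        raise ValueError(
            "need at least 2 classes, got %d" % len(sorted_classes))
    # Same construction as origToNewLabels in the runner.
    orig_to_new = {c: i for i, c in enumerate(sorted_classes)}
    return [orig_to_new[y] for y in labels], orig_to_new


new_labels, mapping = reindex_labels([-1.0, 1.0, 1.0, -1.0])
print(new_labels)
print(mapping)
```

Unlike the runner, this sketch always applies the mapping. For integer labels that already cover 0..numClasses-1 the mapping is the identity, so the result is the same as returning the input unchanged.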

examples/src/main/python/mllib/logistic_regression.py

Lines changed: 3 additions & 1 deletion
@@ -30,8 +30,10 @@
 from pyspark.mllib.classification import LogisticRegressionWithSGD
 
 
-# Parse a line of text into an MLlib LabeledPoint object
 def parsePoint(line):
+    """
+    Parse a line of text into an MLlib LabeledPoint object.
+    """
     values = [float(s) for s in line.split(' ')]
     if values[0] == -1: # Convert -1 labels to 0 for MLlib
         values[0] = 0

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 78 additions & 0 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python
 
 import java.nio.{ByteBuffer, ByteOrder}
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.mllib.classification._
@@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
 import org.apache.spark.mllib.recommendation._
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
 import org.apache.spark.mllib.util.MLUtils
@@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable {
     ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
   }
 
+  /**
+   * Java stub for Python mllib DecisionTree.train().
+   * This stub returns a handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+   * see the Py4J documentation.
+   * @param dataBytesJRDD Training data
+   * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+   */
+  def trainDecisionTreeModel(
+      dataBytesJRDD: JavaRDD[Array[Byte]],
+      algoStr: String,
+      numClasses: Int,
+      categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+      impurityStr: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+
+    val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
+
+    val algo: Algo = algoStr match {
+      case "classification" => Classification
+      case "regression" => Regression
+      case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
+    }
+    val impurity: Impurity = impurityStr match {
+      case "gini" => Gini
+      case "entropy" => Entropy
+      case "variance" => Variance
+      case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
+    }
+
+    val strategy = new Strategy(
+      algo = algo,
+      impurity = impurity,
+      maxDepth = maxDepth,
+      numClassesForClassification = numClasses,
+      maxBins = maxBins,
+      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+
+    DecisionTree.train(data, strategy)
+  }
+
+  /**
+   * Predict the label of the given data point.
+   * This is a Java stub for python DecisionTreeModel.predict()
+   *
+   * @param featuresBytes Serialized feature vector for data point
+   * @return predicted label
+   */
+  def predictDecisionTreeModel(
+      model: DecisionTreeModel,
+      featuresBytes: Array[Byte]): Double = {
+    val features: Vector = deserializeDoubleVector(featuresBytes)
+    model.predict(features)
+  }
+
+  /**
+   * Predict the labels of the given data points.
+   * This is a Java stub for python DecisionTreeModel.predict()
+   *
+   * @param dataJRDD A JavaRDD with serialized feature vectors
+   * @return JavaRDD of serialized predictions
+   */
+  def predictDecisionTreeModel(
+      model: DecisionTreeModel,
+      dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+    val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    model.predict(data).map(serializeDouble)
+  }
+
   /**
    * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String).
    * Returns the correlation matrix serialized into a byte array understood by deserializers in
@@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable {
     val s = getSeedOrDefault(seed)
     RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
   }
+
 }
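
The Scala stub above accepts only a fixed set of strings for algoStr and impurityStr and throws on anything else. A caller on the Python side could check the same values before crossing into the JVM. This is a minimal sketch under that assumption; `check_tree_params` and the two tuples are hypothetical names and are not part of this commit:

```python
# Accepted values, copied from the pattern matches in the Scala stub.
ALGOS = ("classification", "regression")
IMPURITIES = ("gini", "entropy", "variance")


def check_tree_params(algo_str, impurity_str):
    """Reject any string the Scala stub's pattern matches would not accept."""
    if algo_str not in ALGOS:
        raise ValueError("Bad algoStr parameter: %s" % algo_str)
    if impurity_str not in IMPURITIES:
        raise ValueError("Bad impurityStr parameter: %s" % impurity_str)
    return algo_str, impurity_str


print(check_tree_params("classification", "gini"))
```

The comparison is exact and case-sensitive, matching the string literals in the stub, so a value such as "Gini" is rejected.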

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ class Strategy (
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
   }
-  val isMulticlassClassification = numClassesForClassification > 2
+  val isMulticlassClassification =
+    algo == Classification && numClassesForClassification > 2
   val isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
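
The effect of the Strategy change can be seen by comparing the old and new expressions side by side. This is an illustrative Python rendering only; the function names and the string values for the algorithm are assumptions made for the sketch, not MLlib identifiers:

```python
def is_multiclass_old(algo, num_classes):
    """Expression before the change: the algorithm is not consulted."""
    return num_classes > 2


def is_multiclass_new(algo, num_classes):
    """Expression after the change: only classification can be multiclass."""
    return algo == "classification" and num_classes > 2


# A regression strategy that happens to carry a class count above 2.
print(is_multiclass_old("regression", 3), is_multiclass_new("regression", 3))
```

With the old expression, a regression strategy whose class count exceeded 2 would be flagged as multiclass; the new expression rules that out.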

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       requiredMSE: Double) {
     val predictions = input.map(x => model.predict(x.features))
     val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      (prediction - expected.label) * (prediction - expected.label)
+      val err = prediction - expected.label
+      err * err
     }.sum
     val mse = squaredError / input.length
     assert(mse <= requiredMSE)

python/pyspark/mllib/_common.py

Lines changed: 23 additions & 10 deletions
@@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype):
     temp_array[...] = array
 
 
-def _get_unmangled_rdd(data, serializer):
+def _get_unmangled_rdd(data, serializer, cache=True):
+    """
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
     dataBytes = data.map(serializer)
     dataBytes._bypass_serializer = True
-    dataBytes.cache() # TODO: users should unpersist() this later!
+    if cache:
+        dataBytes.cache()
     return dataBytes
 
 
-# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
-# _serialized_double_vectors
-def _get_unmangled_double_vector_rdd(data):
-    return _get_unmangled_rdd(data, _serialize_double_vector)
+def _get_unmangled_double_vector_rdd(data, cache=True):
+    """
+    Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
+    _serialized_double_vectors.
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
+    return _get_unmangled_rdd(data, _serialize_double_vector, cache)
 
 
-# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points
-def _get_unmangled_labeled_point_rdd(data):
-    return _get_unmangled_rdd(data, _serialize_labeled_point)
+def _get_unmangled_labeled_point_rdd(data, cache=True):
+    """
+    Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points.
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
+    return _get_unmangled_rdd(data, _serialize_labeled_point, cache)
 
 
 # Common functions for dealing with and training linear models
@@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs):
         if x.size != coeffs.shape[0]:
             raise RuntimeError("Got sparse vector of size %d; wanted %d" % (
                 x.size, coeffs.shape[0]))
-    elif (type(x) == RDD):
+    elif isinstance(x, RDD):
         raise RuntimeError("Bulk predict not yet supported.")
     else:
         raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
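
The switch from type() to isinstance() in _linear_predictor_typecheck matters because an exact type comparison does not match subclasses. The sketch below shows the difference using two stand-in classes; `BaseRDD` and `DerivedRDD` are hypothetical placeholders for pyspark's RDD and one of its subclasses, so that the example runs without Spark:

```python
class BaseRDD(object):
    """Stand-in for pyspark's RDD class (illustration only)."""


class DerivedRDD(BaseRDD):
    """Stand-in for an RDD subclass."""


x = DerivedRDD()

# type() compares against the exact class, so a subclass instance is missed.
exact_match = type(x) == BaseRDD

# isinstance() also accepts instances of subclasses.
subclass_match = isinstance(x, BaseRDD)

print(exact_match, subclass_match)
```

Under the old check, an RDD subclass instance would fall through to the final else branch and raise the generic TypeError instead of the "Bulk predict not yet supported." error.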
