
Commit bc1fc9b

Kazuki Taniguchi authored and mengxr committed
[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
This PR implements a Python API for Gradient Boosted Trees in MLlib.

Author: Kazuki Taniguchi <[email protected]>

Closes apache#3951 from kazk1018/gbt_for_py and squashes the following commits:

620d247 [Kazuki Taniguchi] [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
1 parent dd4d84c commit bc1fc9b
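For a quick sense of the API surface this commit adds, here is a minimal sketch of calling the new Python GradientBoostedTrees wrapper. It is condensed from the example file included in this commit; the app name and parameter values are illustrative, not prescribed by the PR:

from pyspark import SparkContext
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.util import MLUtils

# Illustrative quick start; mirrors the bundled example rather than adding new API.
sc = SparkContext(appName="GBTQuickStart")
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

# An empty categoricalFeaturesInfo dict means every feature is treated as continuous.
model = GradientBoostedTrees.trainClassifier(
    data, categoricalFeaturesInfo={}, numIterations=10, maxDepth=4)

# Predict on the training features and inspect the learned ensemble.
predictions = model.predict(data.map(lambda lp: lp.features))
print(model.toDebugString())

sc.stop()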

File tree

4 files changed: +318 -56 lines changed

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Gradient boosted Trees classification and regression using MLlib.
"""

import sys

from pyspark.context import SparkContext
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.util import MLUtils


def testClassification(trainingData, testData):
    # Train a GradientBoostedTrees model.
    # Empty categoricalFeaturesInfo indicates all features are continuous.
    model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={},
                                                 numIterations=30, maxDepth=4)
    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \
        / float(testData.count())
    print('Test Error = ' + str(testErr))
    print('Learned classification ensemble model:')
    print(model.toDebugString())


def testRegression(trainingData, testData):
    # Train a GradientBoostedTrees model.
    # Empty categoricalFeaturesInfo indicates all features are continuous.
    model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                                numIterations=30, maxDepth=4)
    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \
        / float(testData.count())
    print('Test Mean Squared Error = ' + str(testMSE))
    print('Learned regression ensemble model:')
    print(model.toDebugString())


if __name__ == "__main__":
    if len(sys.argv) > 1:
        print >> sys.stderr, "Usage: gradient_boosted_trees"
        exit(1)
    sc = SparkContext(appName="PythonGradientBoostedTrees")

    # Load and parse the data file into an RDD of LabeledPoint.
    data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
    # Split the data into training and test sets (30% held out for testing)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])

    print('\nRunning example of classification using GradientBoostedTrees\n')
    testClassification(trainingData, testData)

    print('\nRunning example of regression using GradientBoostedTrees\n')
    testRegression(trainingData, testData)

    sc.stop()

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 33 additions & 3 deletions
@@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
 import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
 import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.loss.Losses
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -532,6 +533,35 @@ class PythonMLLibAPI extends Serializable {
     }
   }
 
+  /**
+   * Java stub for Python mllib GradientBoostedTrees.train().
+   * This stub returns a handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+   * see the Py4J documentation.
+   */
+  def trainGradientBoostedTreesModel(
+      data: JavaRDD[LabeledPoint],
+      algoStr: String,
+      categoricalFeaturesInfo: JMap[Int, Int],
+      lossStr: String,
+      numIterations: Int,
+      learningRate: Double,
+      maxDepth: Int): GradientBoostedTreesModel = {
+    val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
+    boostingStrategy.setLoss(Losses.fromString(lossStr))
+    boostingStrategy.setNumIterations(numIterations)
+    boostingStrategy.setLearningRate(learningRate)
+    boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
+    boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
+
+    val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    try {
+      GradientBoostedTrees.train(cached, boostingStrategy)
+    } finally {
+      cached.unpersist(blocking = false)
+    }
+  }
+
   /**
    * Java stub for mllib Statistics.colStats(X: RDD[Vector]).
    * TODO figure out return type.
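
The matching Python wrapper lives in python/pyspark/mllib/tree.py (one of the four changed files, whose diff is not shown in this view) and reaches this stub over Py4J. Below is a rough sketch of that call path using the existing callMLlibFunc helper from pyspark.mllib.common; the helper name _train_gbt, the default values, and the bare return of the Java handle are illustrative simplifications, not the exact code added by the PR:

from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.regression import LabeledPoint


def _train_gbt(data, algo, categoricalFeaturesInfo, loss="logLoss",
               numIterations=100, learningRate=0.1, maxDepth=3):
    """Illustrative sketch: forward training to the JVM-side stub
    PythonMLLibAPI.trainGradientBoostedTreesModel via Py4J."""
    first = data.first()
    assert isinstance(first, LabeledPoint), "data should be an RDD of LabeledPoint"
    # callMLlibFunc serializes the arguments, invokes the named method on the
    # JVM-side PythonMLLibAPI instance, and returns a handle to the resulting
    # Java GradientBoostedTreesModel; the real wrapper wraps this handle in a
    # Python model class that exposes predict() and toDebugString().
    return callMLlibFunc("trainGradientBoostedTreesModel", data, algo,
                         categoricalFeaturesInfo, loss, numIterations,
                         learningRate, maxDepth)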

python/pyspark/mllib/tests.py

Lines changed: 34 additions & 7 deletions
@@ -169,7 +169,7 @@ def test_kmeans_deterministic(self):
 
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
-        from pyspark.mllib.tree import DecisionTree
+        from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
         data = [
             LabeledPoint(0.0, [1, 0, 0]),
             LabeledPoint(1.0, [0, 1, 1]),
@@ -198,18 +198,31 @@ def test_classification(self):
         self.assertTrue(nb_model.predict(features[3]) > 0)
 
         categoricalFeaturesInfo = {0: 3}  # feature 0 has 3 categories
-        dt_model = \
-            DecisionTree.trainClassifier(rdd, numClasses=2,
-                                         categoricalFeaturesInfo=categoricalFeaturesInfo)
+        dt_model = DecisionTree.trainClassifier(
+            rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(dt_model.predict(features[0]) <= 0)
         self.assertTrue(dt_model.predict(features[1]) > 0)
         self.assertTrue(dt_model.predict(features[2]) <= 0)
         self.assertTrue(dt_model.predict(features[3]) > 0)
 
+        rf_model = RandomForest.trainClassifier(
+            rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
+        self.assertTrue(rf_model.predict(features[0]) <= 0)
+        self.assertTrue(rf_model.predict(features[1]) > 0)
+        self.assertTrue(rf_model.predict(features[2]) <= 0)
+        self.assertTrue(rf_model.predict(features[3]) > 0)
+
+        gbt_model = GradientBoostedTrees.trainClassifier(
+            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+        self.assertTrue(gbt_model.predict(features[0]) <= 0)
+        self.assertTrue(gbt_model.predict(features[1]) > 0)
+        self.assertTrue(gbt_model.predict(features[2]) <= 0)
+        self.assertTrue(gbt_model.predict(features[3]) > 0)
+
     def test_regression(self):
         from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
             RidgeRegressionWithSGD
-        from pyspark.mllib.tree import DecisionTree
+        from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
         data = [
             LabeledPoint(-1.0, [0, -1]),
             LabeledPoint(1.0, [0, 1]),
@@ -238,13 +251,27 @@ def test_regression(self):
         self.assertTrue(rr_model.predict(features[3]) > 0)
 
         categoricalFeaturesInfo = {0: 2}  # feature 0 has 2 categories
-        dt_model = \
-            DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+        dt_model = DecisionTree.trainRegressor(
+            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(dt_model.predict(features[0]) <= 0)
         self.assertTrue(dt_model.predict(features[1]) > 0)
         self.assertTrue(dt_model.predict(features[2]) <= 0)
         self.assertTrue(dt_model.predict(features[3]) > 0)
 
+        rf_model = RandomForest.trainRegressor(
+            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
+        self.assertTrue(rf_model.predict(features[0]) <= 0)
+        self.assertTrue(rf_model.predict(features[1]) > 0)
+        self.assertTrue(rf_model.predict(features[2]) <= 0)
+        self.assertTrue(rf_model.predict(features[3]) > 0)
+
+        gbt_model = GradientBoostedTrees.trainRegressor(
+            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+        self.assertTrue(gbt_model.predict(features[0]) <= 0)
+        self.assertTrue(gbt_model.predict(features[1]) > 0)
+        self.assertTrue(gbt_model.predict(features[2]) <= 0)
+        self.assertTrue(gbt_model.predict(features[3]) > 0)
+
 
 class StatTests(PySparkTestCase):
     # SPARK-4023
