Skip to content

Commit 620d247

Browse files
author
Kazuki Taniguchi
committed
[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
Check lint-python and lint-scala; [SPARK-5094][MLlib] Add some key params for Gradient Boosted Trees in Python API; Fix issues; Fix some issues; Fix the issues (for changing BoostingStrategy.defaultParams() in master); Fix the issues; Added comments about loss functions
1 parent c66a976 commit 620d247

File tree

4 files changed

+318
-56
lines changed

4 files changed

+318
-56
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Gradient Boosted Trees classification and regression using MLlib.
20+
"""
21+
22+
import sys
23+
24+
from pyspark.context import SparkContext
25+
from pyspark.mllib.tree import GradientBoostedTrees
26+
from pyspark.mllib.util import MLUtils
27+
28+
29+
def testClassification(trainingData, testData):
    """Train a GradientBoostedTrees classifier and print its test error.

    The fraction of misclassified test points and the learned ensemble
    (via ``toDebugString``) are printed to stdout; nothing is returned.

    :param trainingData: RDD of LabeledPoint used to fit the model.
    :param testData: RDD of LabeledPoint used to evaluate the model.
    """
    # Train a GradientBoostedTrees model.
    # Empty categoricalFeaturesInfo indicates all features are continuous.
    model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={},
                                                 numIterations=30, maxDepth=4)
    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # Index into each (label, prediction) pair instead of unpacking it in the
    # lambda signature: tuple parameters were removed in Python 3 (PEP 3113),
    # while indexing works on both Python 2 and Python 3.
    numWrong = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count()
    testErr = numWrong / float(testData.count())
    print('Test Error = ' + str(testErr))
    print('Learned classification ensemble model:')
    print(model.toDebugString())
42+
43+
44+
def testRegression(trainingData, testData):
    """Train a GradientBoostedTrees regressor and print its test MSE.

    The mean squared error over the test points and the learned ensemble
    (via ``toDebugString``) are printed to stdout; nothing is returned.

    :param trainingData: RDD of LabeledPoint used to fit the model.
    :param testData: RDD of LabeledPoint used to evaluate the model.
    """
    # Train a GradientBoostedTrees model.
    # Empty categoricalFeaturesInfo indicates all features are continuous.
    model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                                numIterations=30, maxDepth=4)
    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # Index into each (label, prediction) pair instead of unpacking it in the
    # lambda signature: tuple parameters were removed in Python 3 (PEP 3113),
    # while indexing works on both Python 2 and Python 3.
    squaredErrors = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1]))
    testMSE = squaredErrors.sum() / float(testData.count())
    print('Test Mean Squared Error = ' + str(testMSE))
    print('Learned regression ensemble model:')
    print(model.toDebugString())
57+
58+
59+
# Script entry point: runs the classification and regression examples on the
# bundled sample LibSVM data set. The script takes no command-line arguments.
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # sys.stderr.write behaves the same on Python 2 and Python 3, unlike
        # the Python-2-only "print >> sys.stderr" statement.
        sys.stderr.write("Usage: gradient_boosted_trees\n")
        # sys.exit is always available, whereas the bare exit() helper is only
        # injected by the site module.
        sys.exit(1)
    sc = SparkContext(appName="PythonGradientBoostedTrees")
    try:
        # Load and parse the data file into an RDD of LabeledPoint.
        data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
        # Split the data into training and test sets (30% held out for testing)
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        print('\nRunning example of classification using GradientBoostedTrees\n')
        testClassification(trainingData, testData)

        print('\nRunning example of regression using GradientBoostedTrees\n')
        testRegression(trainingData, testData)
    finally:
        # Stop the SparkContext even if one of the examples raises.
        sc.stop()

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._
4141
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
4242
import org.apache.spark.mllib.stat.correlation.CorrelationNames
4343
import org.apache.spark.mllib.stat.test.ChiSqTestResult
44-
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
45-
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
44+
import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
45+
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
4646
import org.apache.spark.mllib.tree.impurity._
47-
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
47+
import org.apache.spark.mllib.tree.loss.Losses
48+
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
4849
import org.apache.spark.mllib.util.MLUtils
4950
import org.apache.spark.rdd.RDD
5051
import org.apache.spark.storage.StorageLevel
@@ -528,6 +529,35 @@ class PythonMLLibAPI extends Serializable {
528529
}
529530
}
530531

532+
/**
 * Java stub for Python mllib GradientBoostedTrees.train().
 *
 * Builds a BoostingStrategy from the default parameters for `algoStr`, overrides the
 * loss, number of iterations, learning rate, maximum tree depth and categorical
 * feature arities with the supplied values, and trains on a persisted copy of `data`.
 *
 * This stub returns a handle to the Java object instead of the content of the Java object.
 * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
 * see the Py4J documentation.
 */
def trainGradientBoostedTreesModel(
    data: JavaRDD[LabeledPoint],
    algoStr: String,
    categoricalFeaturesInfo: JMap[Int, Int],
    lossStr: String,
    numIterations: Int,
    learningRate: Double,
    maxDepth: Int): GradientBoostedTreesModel = {
  // Boosting-level settings.
  val strategy = BoostingStrategy.defaultParams(algoStr)
  strategy.setLoss(Losses.fromString(lossStr))
  strategy.setNumIterations(numIterations)
  strategy.setLearningRate(learningRate)

  // Per-tree settings live on the nested tree strategy.
  val treeConf = strategy.treeStrategy
  treeConf.setMaxDepth(maxDepth)
  treeConf.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap

  // Persist the input for the duration of training and always release it afterwards,
  // whether or not training succeeds.
  val trainingSet = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
  try {
    GradientBoostedTrees.train(trainingSet, strategy)
  } finally {
    trainingSet.unpersist(blocking = false)
  }
}
560+
531561
/**
532562
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
533563
* TODO figure out return type.

python/pyspark/mllib/tests.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_clustering(self):
142142

143143
def test_classification(self):
144144
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
145-
from pyspark.mllib.tree import DecisionTree
145+
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
146146
data = [
147147
LabeledPoint(0.0, [1, 0, 0]),
148148
LabeledPoint(1.0, [0, 1, 1]),
@@ -171,18 +171,31 @@ def test_classification(self):
171171
self.assertTrue(nb_model.predict(features[3]) > 0)
172172

173173
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
174-
dt_model = \
175-
DecisionTree.trainClassifier(rdd, numClasses=2,
176-
categoricalFeaturesInfo=categoricalFeaturesInfo)
174+
dt_model = DecisionTree.trainClassifier(
175+
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo)
177176
self.assertTrue(dt_model.predict(features[0]) <= 0)
178177
self.assertTrue(dt_model.predict(features[1]) > 0)
179178
self.assertTrue(dt_model.predict(features[2]) <= 0)
180179
self.assertTrue(dt_model.predict(features[3]) > 0)
181180

181+
rf_model = RandomForest.trainClassifier(
182+
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
183+
self.assertTrue(rf_model.predict(features[0]) <= 0)
184+
self.assertTrue(rf_model.predict(features[1]) > 0)
185+
self.assertTrue(rf_model.predict(features[2]) <= 0)
186+
self.assertTrue(rf_model.predict(features[3]) > 0)
187+
188+
gbt_model = GradientBoostedTrees.trainClassifier(
189+
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
190+
self.assertTrue(gbt_model.predict(features[0]) <= 0)
191+
self.assertTrue(gbt_model.predict(features[1]) > 0)
192+
self.assertTrue(gbt_model.predict(features[2]) <= 0)
193+
self.assertTrue(gbt_model.predict(features[3]) > 0)
194+
182195
def test_regression(self):
183196
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
184197
RidgeRegressionWithSGD
185-
from pyspark.mllib.tree import DecisionTree
198+
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
186199
data = [
187200
LabeledPoint(-1.0, [0, -1]),
188201
LabeledPoint(1.0, [0, 1]),
@@ -211,13 +224,27 @@ def test_regression(self):
211224
self.assertTrue(rr_model.predict(features[3]) > 0)
212225

213226
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
214-
dt_model = \
215-
DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
227+
dt_model = DecisionTree.trainRegressor(
228+
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
216229
self.assertTrue(dt_model.predict(features[0]) <= 0)
217230
self.assertTrue(dt_model.predict(features[1]) > 0)
218231
self.assertTrue(dt_model.predict(features[2]) <= 0)
219232
self.assertTrue(dt_model.predict(features[3]) > 0)
220233

234+
rf_model = RandomForest.trainRegressor(
235+
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
236+
self.assertTrue(rf_model.predict(features[0]) <= 0)
237+
self.assertTrue(rf_model.predict(features[1]) > 0)
238+
self.assertTrue(rf_model.predict(features[2]) <= 0)
239+
self.assertTrue(rf_model.predict(features[3]) > 0)
240+
241+
gbt_model = GradientBoostedTrees.trainRegressor(
242+
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
243+
self.assertTrue(gbt_model.predict(features[0]) <= 0)
244+
self.assertTrue(gbt_model.predict(features[1]) > 0)
245+
self.assertTrue(gbt_model.predict(features[2]) <= 0)
246+
self.assertTrue(gbt_model.predict(features[3]) > 0)
247+
221248

222249
class StatTests(PySparkTestCase):
223250
# SPARK-4023

0 commit comments

Comments
 (0)