Commit 50a4fa7

[SPARK-3443][MLLIB] update default values of tree:
Adjust the default values of decision tree, based on the memory requirement discussed in #2125:

1. maxMemoryInMB: 128 -> 256
2. maxBins: 100 -> 32
3. maxDepth: 4 -> 5 (in some example code)

jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #2322 from mengxr/tree-defaults and squashes the following commits:

cda453a [Xiangrui Meng] fix tests
5900445 [Xiangrui Meng] update comments
8c81831 [Xiangrui Meng] update default values of tree
1 parent 7db5339 commit 50a4fa7

File tree: 7 files changed, +24 −34 lines


docs/mllib-decision-tree.md

Lines changed: 8 additions & 8 deletions
@@ -80,7 +80,7 @@ The ordered splits create "bins" and the maximum number of such
 bins can be specified using the `maxBins` parameter.
 
 Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
-since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
+since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of
 bins if the condition is not satisfied.
 
 **Categorical features**
@@ -117,7 +117,7 @@ all nodes at each level of the tree. This could lead to high memory requirements
 of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB`
 training parameter specifies the maximum amount of memory at the workers (twice as much at the
 master) to be allocated to the histogram computation. The default value is conservatively chosen to
-be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
+be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
 for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
 subsequent level are split into smaller tasks.
 
@@ -167,7 +167,7 @@ val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "gini"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -213,7 +213,7 @@ Integer numClasses = 2;
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
@@ -250,7 +250,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=100)
+                                     impurity='gini', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))
@@ -293,7 +293,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "variance"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -338,7 +338,7 @@ JavaSparkContext sc = new JavaSparkContext(sparkConf);
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "variance";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model.
 final DecisionTreeModel model = DecisionTree.trainRegressor(data,
@@ -380,7 +380,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=100)
+                                    impurity='variance', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))
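Taken together, the docs hunks above show the classification workflow with the new defaults. The following is a minimal, self-contained Scala sketch assembled from those snippets; the `TreeDefaultsExample` object wrapper and the SparkConf setup are scaffolding added here, while the data path, parameter values, and MLlib calls come from the documentation itself.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    object TreeDefaultsExample {
      def main(args: Array[String]) {
        val sc = new SparkContext(new SparkConf().setAppName("TreeDefaultsExample"))
        // Sample dataset bundled with Spark; the path is relative to the Spark checkout.
        val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

        val numClasses = 2
        val categoricalFeaturesInfo = Map[Int, Int]() // empty map: all features continuous
        val impurity = "gini"
        val maxDepth = 5 // updated suggested depth (was 4)
        val maxBins = 32 // updated default bin count (was 100)

        val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo,
          impurity, maxDepth, maxBins)

        // Fraction of misclassified training points, as in the docs example.
        val labelAndPreds = data.map(p => (p.label, model.predict(p.features)))
        val trainErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / data.count()
        println("Training Error = " + trainErr)

        sc.stop()
      }
    }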

examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ public static void main(String[] args) {
     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
     String impurity = "gini";
     Integer maxDepth = 5;
-    Integer maxBins = 100;
+    Integer maxBins = 32;
 
     // Train a DecisionTree model for classification.
     final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 2 additions & 2 deletions
@@ -52,9 +52,9 @@ object DecisionTreeRunner {
       input: String = null,
       dataFormat: String = "libsvm",
       algo: Algo = Classification,
-      maxDepth: Int = 4,
+      maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
-      maxBins: Int = 100,
+      maxBins: Int = 32,
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 4 additions & 4 deletions
@@ -330,9 +330,9 @@ object DecisionTree extends Serializable with Logging {
    *                 Supported values: "gini" (recommended) or "entropy".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (suggested value: 4)
+   *                 (suggested value: 5)
    * @param maxBins maximum number of bins used for splitting features
-   *                (suggested value: 100)
+   *                (suggested value: 32)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainClassifier(
@@ -374,9 +374,9 @@ object DecisionTree extends Serializable with Logging {
    *                 Supported values: "variance".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (suggested value: 4)
+   *                 (suggested value: 5)
    * @param maxBins maximum number of bins used for splitting features
-   *                (suggested value: 100)
+   *                (suggested value: 32)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainRegressor(
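As a companion to the updated scaladoc, here is a hedged sketch of the regression entry point with the newly suggested values. It assumes a live SparkContext named `sc` and the sample dataset used throughout the docs; the `trainRegressor` call follows the signature documented above, and the error computation avoids implicit RDD conversions for clarity.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    // Assumes an already-created SparkContext `sc`.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

    // Empty map: every feature is treated as continuous.
    val categoricalFeaturesInfo = Map[Int, Int]()

    // Suggested values from the scaladoc above: maxDepth = 5, maxBins = 32.
    val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, "variance", 5, 32)

    // Training mean squared error as a quick sanity check.
    val squaredErrorSum = data.map { p =>
      val err = model.predict(p.features) - p.label
      err * err
    }.reduce(_ + _)
    println("Training MSE = " + squaredErrorSum / data.count())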

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 3 additions & 3 deletions
@@ -50,18 +50,18 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- *                      128 MB.
+ *                      256 MB.
  */
 @Experimental
 class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
     val numClassesForClassification: Int = 2,
-    val maxBins: Int = 100,
+    val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable {
+    val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
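Since every constructor parameter after the first three carries a default, callers can rely on the new values or override them by name, as the updated tests do. A minimal sketch under that assumption (the 512 MB figure is illustrative only, not a recommendation from this commit):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Picks up the new defaults: maxBins = 32, maxMemoryInMB = 256.
    val defaultStrategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5)

    // Explicitly raising the histogram-aggregation budget per job.
    val roomierStrategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      maxMemoryInMB = 512)

Either instance can then be handed to training, e.g. `DecisionTree.train(data, defaultStrategy)` for an `RDD[LabeledPoint]` named `data`.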

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 4 additions & 14 deletions
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
-
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   def validateClassifier(
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins.length === 2)
     assert(bins(0).length === 100)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
 
     // Train a 1-node model
     val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3)
+      numClassesForClassification = 3, maxBins = 100)
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+      numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(metadata.isUnordered(featureIndex = 0))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+      numClassesForClassification = 3, maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
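The suite edits above follow one pattern: tests that assert exact bin and split counts now pin `maxBins = 100` through `Strategy`'s named arguments, since the default fell to 32. A condensed sketch of the relationship those assertions encode (the `Strategy` call mirrors the diff; the arithmetic comment is the reasoning behind the paired 99/100 counts):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Pin the old bin count so assertions like `bins(0).length === 100` still hold.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
      numClassesForClassification = 3, maxBins = 100)

    // A continuous feature quantized into B bins has B - 1 interior thresholds,
    // hence the suite asserts 100 bins alongside 99 splits per feature.
    val expectedSplits = strategy.maxBins - 1 // 99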

python/pyspark/mllib/tree.py

Lines changed: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ class DecisionTree(object):
 
     @staticmethod
     def trainClassifier(data, numClasses, categoricalFeaturesInfo,
-                        impurity="gini", maxDepth=4, maxBins=100):
+                        impurity="gini", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for classification.
 
@@ -170,7 +170,7 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
 
     @staticmethod
     def trainRegressor(data, categoricalFeaturesInfo,
-                       impurity="variance", maxDepth=4, maxBins=100):
+                       impurity="variance", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for regression.
 
