@@ -12,7 +12,7 @@ and their ensembles are popular methods for the machine learning tasks of
classification and regression. Decision trees are widely used since they are easy to interpret,
handle categorical features, extend to the multiclass classification setting, do not require
feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble
- algorithms such as decision forests and boosting are among the top performers for classification and
+ algorithms such as random forests and boosting are among the top performers for classification and
regression tasks.

MLlib supports decision trees for binary and multiclass classification and for regression,
@@ -94,13 +94,13 @@ Section 9.2.4 in
details). For example, for a binary classification problem with one categorical feature with three
categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical
features are ordered as A, C, B. The two split candidates are A \| C, B
- and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
- when `$2^{M-1}-1$` is greater than the `maxBins` parameter: the impurity for each categorical feature value
- is used for ordering. In multiclass classification, all `$2^{M-1}-1$` possible splits are used
- whenever possible.
+ and A , C \| B where \| denotes the split.

- Note that the `maxBins` parameter must be at least `$M_{max}$`, the maximum number of categories for
- any categorical feature.
+ In multiclass classification, all `$2^{M-1}-1$` possible splits are used whenever possible.
+ When `$2^{M-1}-1$` is greater than the `maxBins` parameter, we use a (heuristic) method
+ similar to the method used for binary classification and regression.
+ The `$M$` categorical feature values are ordered by impurity,
+ and the resulting `$M-1$` split candidates are considered.

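As an aside, the ordering heuristic can be made concrete with a small self-contained Scala sketch (illustrative only, not MLlib code), reproducing the A, B, C example above: categories are ordered by their proportion of label 1, and only the `$M-1$` contiguous split candidates are enumerated.

```scala
// Categories A, B, C with label-1 proportions 0.2, 0.6, 0.4 (the example above).
val categoryStats = Seq("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)

// Order the categories by their proportion of label 1: A, C, B.
val ordered = categoryStats.sortBy(_._2).map(_._1)

// Consider only the M - 1 ordered split candidates instead of 2^(M-1) - 1 subsets.
val splitCandidates = (1 until ordered.size).map(i => (ordered.take(i), ordered.drop(i)))

splitCandidates.foreach { case (left, right) =>
  println(left.mkString(", ") + " | " + right.mkString(", "))
}
// A | C, B
// A, C | B
```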
### Stopping rule

@@ -109,6 +109,8 @@ The recursive tree construction is stopped at a node when one of the two conditi
1. The node depth is equal to the `maxDepth` training parameter.
2. No split candidate leads to an information gain at the node.

+ ## Implementation details
+
### Max memory requirements

For faster processing, the decision tree algorithm performs simultaneous histogram computations for
@@ -120,11 +122,24 @@ be 128 MB to allow the decision algorithm to work in most scenarios. Once the me
for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
subsequent level are split into smaller tasks.

- ### Practical limitations
+ Note that, if you have a large amount of memory, increasing `maxMemoryInMB` can lead to faster
+ training by requiring fewer passes over the data.
+
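As a rough sketch of how this might look (assuming the `Strategy`-based training API in `org.apache.spark.mllib.tree.configuration`; constructor arguments vary between MLlib versions, so treat the exact call as an assumption rather than a reference):

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Allow up to 512 MB (instead of the 128 MB default) for the level-wise
// histogram aggregates, so that more nodes can be trained in a single pass.
val strategy = new Strategy(Algo.Classification, Gini, 5, 2, maxMemoryInMB = 512)
val model = DecisionTree.train(data, strategy)
```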
+ ### Binning feature values
+
+ Increasing `maxBins` allows the algorithm to consider more split candidates and make fine-grained
+ split decisions. However, it also increases computation and communication.
+
+ Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for
+ any categorical feature.
+
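For instance (a hypothetical setup, reusing the `trainClassifier` call from the classification example below), a categorical feature with 10 categories requires `maxBins` of at least 10:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.rdd.RDD

// Hypothetical: feature 0 of trainingData is categorical with M = 10 categories
// (values 0.0 through 9.0), so maxBins must be at least 10.
def trainWithCategoricalFeature(trainingData: RDD[LabeledPoint]) = {
  val categoricalFeaturesInfo = Map(0 -> 10)
  val maxBins = 32 // >= 10, so the constraint above is satisfied
  DecisionTree.trainClassifier(trainingData, 2, categoricalFeaturesInfo,
    "gini", 5, maxBins)
}
```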
+ ### Scaling
+
+ Computation scales approximately linearly in the number of training instances,
+ in the number of features, and in the `maxBins` parameter.
+ Communication scales approximately linearly in the number of features and in `maxBins`.

- 1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.
- 2. Computation scales approximately linearly in the number of training instances,
- in the number of features, and in the `maxBins` parameter.
+ The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.

## Examples

@@ -143,8 +158,9 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

- // Load and parse the data file
- val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+ // Load and parse the data file.
+ // Cache the data since we will use it again to compute training error.
+ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -187,17 +203,14 @@ import org.apache.spark.SparkConf;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

+ // Load and parse the data file.
+ // Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
- // Compute the number of classes from the data.
- Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
-   @Override public Double call(LabeledPoint p) {
-     return p.label();
-   }
- }).countByValue().size();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
+ Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
@@ -231,8 +244,9 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

- # an RDD of LabeledPoint
- data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+ # Load and parse the data file into an RDD of LabeledPoint.
+ # Cache the data since we will use it again to compute training error.
+ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -271,8 +285,9 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

- // Load and parse the data file
- val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+ // Load and parse the data file.
+ // Cache the data since we will use it again to compute training error.
+ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -311,6 +326,8 @@ import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

+ // Load and parse the data file.
+ // Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

@@ -357,8 +374,9 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

- # an RDD of LabeledPoint
- data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+ # Load and parse the data file into an RDD of LabeledPoint.
+ # Cache the data since we will use it again to compute training error.
+ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.