
Commit 9dd1b6b

Updated decision tree doc.
1 parent d802369

1 file changed: +43 -25 lines changed


docs/mllib-decision-tree.md

Lines changed: 43 additions & 25 deletions
@@ -12,7 +12,7 @@ and their ensembles are popular methods for the machine learning tasks of
 classification and regression. Decision trees are widely used since they are easy to interpret,
 handle categorical features, extend to the multiclass classification setting, do not require
 feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble
-algorithms such as decision forests and boosting are among the top performers for classification and
+algorithms such as random forests and boosting are among the top performers for classification and
 regression tasks.

 MLlib supports decision trees for binary and multiclass classification and for regression,
@@ -94,13 +94,13 @@ Section 9.2.4 in
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical
 features are ordered as A, C, B. The two split candidates are A \| C, B
-and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
-when `$2^{M-1}-1$` is greater than the `maxBins` parameter: the impurity for each categorical feature value
-is used for ordering. In multiclass classification, all `$2^{M-1}-1$` possible splits are used
-whenever possible.
+and A , C \| B where \| denotes the split.

-Note that the `maxBins` parameter must be at least `$M_{max}$`, the maximum number of categories for
-any categorical feature.
+In multiclass classification, all `$2^{M-1}-1$` possible splits are used whenever possible.
+When `$2^{M-1}-1$` is greater than the `maxBins` parameter, we use a (heuristic) method
+similar to the method used for binary classification and regression.
+The `$M$` categorical feature values are ordered by impurity,
+and the resulting `$M-1$` split candidates are considered.

 ### Stopping rule
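For reference, the split-ordering heuristic described in this hunk can be illustrated with a few lines of standalone Scala. This sketch is not part of the commit; the category names and label-1 proportions are taken from the A/B/C example in the text above:

```scala
// Categories with their proportion of label 1, from the doc's example.
val proportions = Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)

// Order categories by proportion of label 1: A (0.2), C (0.4), B (0.6).
val ordered = proportions.toSeq.sortBy(_._2).map(_._1)

// M ordered categories yield M - 1 contiguous split candidates:
// here "A | C, B" and "A, C | B".
val splits = (1 until ordered.size).map(i => (ordered.take(i), ordered.drop(i)))

splits.foreach { case (left, right) =>
  println(left.mkString(", ") + " | " + right.mkString(", "))
}
```

The same idea carries over to the multiclass case described above: order the `$M$` category values by impurity instead of by label-1 proportion, and consider only the `$M-1$` contiguous splits.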

@@ -109,6 +109,8 @@ The recursive tree construction is stopped at a node when one of the two conditi
 1. The node depth is equal to the `maxDepth` training parameter.
 2. No split candidate leads to an information gain at the node.

+## Implementation details
+
 ### Max memory requirements

 For faster processing, the decision tree algorithm performs simultaneous histogram computations for
@@ -120,11 +122,24 @@ be 128 MB to allow the decision algorithm to work in most scenarios. Once the me
 for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
 subsequent level are split into smaller tasks.

-### Practical limitations
+Note that, if you have a large amount of memory, increasing `maxMemoryInMB` can lead to faster
+training by requiring fewer passes over the data.
+
+### Binning feature values
+
+Increasing `maxBins` allows the algorithm to consider more split candidates and make fine-grained
+split decisions. However, it also increases computation and communication.
+
+Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for
+any categorical feature.
+
+### Scaling
+
+Computation scales approximately linearly in the number of training instances,
+in the number of features, and in the `maxBins` parameter.
+Communication scales approximately linearly in the number of features and in `maxBins`.

-1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.
-2. Computation scales approximately linearly in the number of training instances,
-in the number of features, and in the `maxBins` parameter.
+The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.

 ## Examples
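To make the new `maxMemoryInMB` and `maxBins` guidance above concrete, here is a hedged Scala sketch in the style of the examples below. The value 64 for `maxBins` is purely illustrative, and the remark about `maxMemoryInMB` is an assumption about where that parameter lives in this release's API:

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]() // empty: all features are continuous
val impurity = "gini"
val maxDepth = 5
// A larger maxBins buys finer-grained splits at the cost of more computation
// and communication; it must be at least the number of categories M of any
// categorical feature. 64 is an illustrative choice, not a recommendation.
val maxBins = 64

val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)
// maxMemoryInMB is not a parameter of this convenience method; it is assumed
// here to be set through the lower-level Strategy configuration class.
```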

@@ -143,8 +158,9 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
 import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils

-// Load and parse the data file
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Load and parse the data file.
+// Cache the data since we will use it again to compute training error.
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
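The new `.cache()` call pays off later in this example, where `data` is traversed again to compute the training error. That part of the file is unchanged by this commit, but for context it looks roughly like this:

```scala
// Evaluate the model on the training instances and compute the training error.
val labelAndPreds = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / data.count()
println("Training Error = " + trainErr)
```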
@@ -187,17 +203,14 @@ import org.apache.spark.SparkConf;
 SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);

+// Load and parse the data file.
+// Cache the data since we will use it again to compute training error.
 String datapath = "data/mllib/sample_libsvm_data.txt";
 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
-// Compute the number of classes from the data.
-Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
-  @Override public Double call(LabeledPoint p) {
-    return p.label();
-  }
-}).countByValue().size();

 // Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
+Integer numClasses = 2;
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
@@ -231,8 +244,9 @@ from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils

-# an RDD of LabeledPoint
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Load and parse the data file into an RDD of LabeledPoint.
+# Cache the data since we will use it again to compute training error.
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -271,8 +285,9 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
 import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils

-// Load and parse the data file
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Load and parse the data file.
+// Cache the data since we will use it again to compute training error.
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
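As in the classification example, the cached `data` is reused when scoring the regression model. The MSE computation in the unchanged remainder of this example looks roughly like this:

```scala
// Evaluate the model on the training instances and compute the Mean Squared Error.
val labelsAndPredictions = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Training Mean Squared Error = " + trainMSE)
```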
@@ -311,6 +326,8 @@ import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;

+// Load and parse the data file.
+// Cache the data since we will use it again to compute training error.
 String datapath = "data/mllib/sample_libsvm_data.txt";
 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

@@ -357,8 +374,9 @@ from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils

-# an RDD of LabeledPoint
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Load and parse the data file into an RDD of LabeledPoint.
+# Cache the data since we will use it again to compute training error.
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
