docs/mllib-decision-tree.md — 16 changes: 8 additions & 8 deletions
@@ -80,7 +80,7 @@ The ordered splits create "bins" and the maximum number of such
 bins can be specified using the `maxBins` parameter.
 
 Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
-since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
+since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of
 bins if the condition is not satisfied.
 
 **Categorical features**
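
To make the cap concrete, here is a minimal sketch of the reduction described above (a hypothetical helper, not MLlib's actual code):

    // The trainer can never use more bins than there are training instances,
    // so the effective bin count is the smaller of the two values.
    def effectiveBins(maxBins: Int, numInstances: Long): Long =
      math.min(maxBins.toLong, numInstances)

    // With the new default maxBins = 32 and a tiny dataset of 20 instances:
    // effectiveBins(32, 20) == 20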
@@ -117,7 +117,7 @@ all nodes at each level of the tree. This could lead to high memory requirements
 of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB`
 training parameter specifies the maximum amount of memory at the workers (twice as much at the
 master) to be allocated to the histogram computation. The default value is conservatively chosen to
-be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
+be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
 for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
 subsequent level are split into smaller tasks.
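
The task splitting in that last paragraph can be pictured with a rough sketch (an illustration of the idea only, not Spark's actual internals; `bytesPerNode` is an assumed per-node histogram size):

    // If aggregating histograms for every node at one level would exceed the
    // maxMemoryInMB budget, process the level's nodes in smaller groups.
    def numNodeGroups(nodesAtLevel: Int, bytesPerNode: Long, maxMemoryInMB: Int): Int = {
      val budgetBytes = maxMemoryInMB.toLong * 1024 * 1024
      val nodesPerGroup = math.max(1L, budgetBytes / bytesPerNode)
      math.ceil(nodesAtLevel.toDouble / nodesPerGroup).toInt
    }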

@@ -167,7 +167,7 @@ val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "gini"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -213,7 +213,7 @@ Integer numClasses = 2;
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
@@ -250,7 +250,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=100)
+                                     impurity='gini', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))
@@ -293,7 +293,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
 val categoricalFeaturesInfo = Map[Int, Int]()
 val impurity = "variance"
 val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
 
 val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
@@ -338,7 +338,7 @@ JavaSparkContext sc = new JavaSparkContext(sparkConf);
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "variance";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model.
 final DecisionTreeModel model = DecisionTree.trainRegressor(data,
@@ -380,7 +380,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
 model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=100)
+                                    impurity='variance', maxDepth=5, maxBins=32)
 
 # Evaluate model on training instances and compute training error
 predictions = model.predict(data.map(lambda x: x.features))
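
Assembled from the documented snippets above, a full Scala run under the new defaults looks roughly like this (a sketch; it assumes an existing SparkContext `sc` and the bundled sample data file):

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

    // Empty categoricalFeaturesInfo: all features are treated as continuous.
    // The explicit maxDepth = 5 and maxBins = 32 now match the defaults.
    val model = DecisionTree.trainClassifier(data, 2, Map[Int, Int](),
      "gini", 5, 32)

    // Evaluate on the training set and compute the training error,
    // as in the docs example above.
    val labelAndPreds = data.map { point =>
      (point.label, model.predict(point.features))
    }
    val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
    println("Training Error = " + trainErr)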
Java classification example:
@@ -63,7 +63,7 @@ public static void main(String[] args) {
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
DecisionTreeRunner.scala:
@@ -52,9 +52,9 @@ object DecisionTreeRunner {
      input: String = null,
      dataFormat: String = "libsvm",
      algo: Algo = Classification,
-     maxDepth: Int = 4,
+     maxDepth: Int = 5,
      impurity: ImpurityType = Gini,
-     maxBins: Int = 100,
+     maxBins: Int = 32,
      fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
DecisionTree.scala:
@@ -330,9 +330,9 @@ object DecisionTree extends Serializable with Logging {
   *                 Supported values: "gini" (recommended) or "entropy".
   * @param maxDepth Maximum depth of the tree.
   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-  *                 (suggested value: 4)
+  *                 (suggested value: 5)
   * @param maxBins maximum number of bins used for splitting features
-  *                (suggested value: 100)
+  *                (suggested value: 32)
   * @return DecisionTreeModel that can be used for prediction
   */
  def trainClassifier(
@@ -374,9 +374,9 @@ object DecisionTree extends Serializable with Logging {
   *                 Supported values: "variance".
   * @param maxDepth Maximum depth of the tree.
   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-  *                 (suggested value: 4)
+  *                 (suggested value: 5)
   * @param maxBins maximum number of bins used for splitting features
-  *                (suggested value: 100)
+  *                (suggested value: 32)
   * @return DecisionTreeModel that can be used for prediction
   */
  def trainRegressor(
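
The scaladoc's depth examples generalize; a short sketch of the arithmetic (hypothetical helpers, not part of the API) shows why the suggested maxDepth of 5 is still a small tree:

    // A binary tree of depth d has at most 2^d leaves and 2^(d+1) - 1 nodes:
    // depth 0 -> 1 leaf; depth 1 -> 1 internal node + 2 leaves, as documented.
    def maxLeaves(depth: Int): Long = 1L << depth
    def maxNodes(depth: Int): Long = (1L << (depth + 1)) - 1

    // maxNodes(4) == 31, maxNodes(5) == 63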
Strategy.scala:
@@ -50,18 +50,18 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- *                      128 MB.
+ *                      256 MB.
  */
 @Experimental
 class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
     val numClassesForClassification: Int = 2,
-    val maxBins: Int = 100,
+    val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable {
+    val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
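
Callers who were relying on the old defaults can pin them explicitly, which is exactly what the test changes below do. A sketch using the constructor above:

    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Restore the old bin count under the new defaults by overriding maxBins;
    // unspecified parameters (quantile strategy, maxMemoryInMB = 256, ...)
    // keep their defaults.
    val strategy = new Strategy(algo = Classification, impurity = Gini,
      maxDepth = 5, numClassesForClassification = 2, maxBins = 100)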
DecisionTreeSuite.scala:
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
-
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   def validateClassifier(
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 assert(splits(0).length === 99)
 assert(bins.length === 2)
 assert(bins(0).length === 100)
-assert(splits(0).length === 99)
-assert(bins(0).length === 100)
 
 val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
 val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 assert(splits(0).length === 99)
 assert(bins.length === 2)
 assert(bins(0).length === 100)
-assert(splits(0).length === 99)
-assert(bins(0).length === 100)
 
 val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
 val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 assert(splits(0).length === 99)
 assert(bins.length === 2)
 assert(bins(0).length === 100)
-assert(splits(0).length === 99)
-assert(bins(0).length === 100)
 
 val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
 val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 assert(splits(0).length === 99)
 assert(bins.length === 2)
 assert(bins(0).length === 100)
-assert(splits(0).length === 99)
-assert(bins(0).length === 100)
 
 val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
 val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 assert(splits(0).length === 99)
 assert(bins.length === 2)
 assert(bins(0).length === 100)
-assert(splits(0).length === 99)
-assert(bins(0).length === 100)
 
 // Train a 1-node model
 val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-  numClassesForClassification = 3)
+  numClassesForClassification = 3, maxBins = 100)
 assert(strategy.isMulticlassClassification)
 val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-  numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+  numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
 assert(strategy.isMulticlassClassification)
 val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 assert(metadata.isUnordered(featureIndex = 0))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-  numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+  numClassesForClassification = 3, maxBins = 100,
+  categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
 assert(strategy.isMulticlassClassification)
 val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 assert(!metadata.isUnordered(featureIndex = 0))
python/pyspark/mllib/tree.py — 4 changes: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ class DecisionTree(object):
 
     @staticmethod
     def trainClassifier(data, numClasses, categoricalFeaturesInfo,
-                        impurity="gini", maxDepth=4, maxBins=100):
+                        impurity="gini", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for classification.
 
@@ -170,7 +170,7 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
 
     @staticmethod
     def trainRegressor(data, categoricalFeaturesInfo,
-                       impurity="variance", maxDepth=4, maxBins=100):
+                       impurity="variance", maxDepth=5, maxBins=32):
         """
         Train a DecisionTreeModel for regression.