From 50b143a4385f209fbc1793f3e03134cab3ab9583 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 20 Apr 2014 13:33:03 -0700 Subject: [PATCH 01/18] adding support for very deep trees --- .../spark/mllib/tree/DecisionTree.scala | 85 +++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 12 +-- 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3019447ce4cd9..ad901d4f67398 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -58,7 +58,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - logDebug("numSplits = " + bins(0).length) + val numBins = bins(0).length + logDebug("numBins = " + numBins) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -72,7 +73,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) + // num features + val numFeatures = input.take(1)(0).features.size + + // Calculate level for single group construction + // Max memory usage for aggregates + val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB + logDebug("max memory usage for aggregates = " + maxMemoryUsage) + val numElementsPerNode = { + strategy.algo match { + case Classification => 2 * numBins * numFeatures + case Regression => 3 * numBins * numFeatures + } + } + logDebug("numElementsPerNode = " + numElementsPerNode) + val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array + val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) + logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) + // nodes at a level is 2^(level-1). level is zero indexed. + val maxLevelForSingleGroup = scala.math.max( + (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0) + logDebug("max level for single group = " + maxLevelForSingleGroup) /* * The main idea here is to perform level-wise training of the decision tree nodes thus @@ -92,7 +114,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins) + level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -110,6 +132,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + logDebug("#####################################") + logDebug("Extracting tree model") + logDebug("#####################################") + // Initialize the top or root node of the tree. val topNode = nodes(0) // Build the full tree using the node info calculated in the level-wise best split calculations. 
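The hunk above derives, from a fixed 128 MB aggregate budget, the deepest tree level whose nodes can all be trained in a single group. The following free-standing Scala sketch restates that arithmetic under stated assumptions: the object and function names are hypothetical, and it uses the corrected 2^level node count that PATCH 11/18 later settles on rather than the off-by-one formula in this first commit.

```scala
// Sketch of the group-size calculation, outside any Spark class (hypothetical names).
object GroupLevelBudget {
  def maxLevelForSingleGroup(numBins: Int, numFeatures: Int, isClassification: Boolean): Int = {
    val maxMemoryUsage = math.pow(2, 27).toLong // 128 MB budget for bin aggregates
    // Classification keeps 2 values per (feature, bin); regression keeps 3 statistics.
    val numElementsPerNode =
      if (isClassification) 2L * numBins * numFeatures
      else 3L * numBins * numFeatures
    val arraySizePerNode = 8L * numElementsPerNode // 8 bytes per Double
    val maxNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1L)
    // A zero-indexed level holds 2^level nodes, so the deepest level that still
    // fits in one group is floor(log2(maxNodesPerGroup)).
    math.max((math.log(maxNodesPerGroup.toDouble) / math.log(2)).floor.toInt, 0)
  }
}
```

For example, with 100 bins and 50 features in a classification problem, each node needs 2 * 100 * 50 = 10,000 doubles (80,000 bytes), so roughly 1,677 node aggregates fit in 128 MB and levels 0 through 10 can each be trained as a single group.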
@@ -260,6 +286,7 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( @@ -269,7 +296,50 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]], + maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation + if (level > maxLevelForSingleGroup) { + val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt + logDebug("numGroups = " + numGroups) + var groupIndex = 0 + var bestSplits = new Array[(Split, InformationGainStats)](0) + while (groupIndex < numGroups) { + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + filters, splits, bins, numGroups, groupIndex) + bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + groupIndex += 1 + } + bestSplits + } else { + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + } + } + + /** + * Returns an array of optimal splits for a group of nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + private def findBestSplitsPerGroup( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]], + numGroups: Int = 1, + groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* * The high-level description of the best split optimizations is noted here. @@ -296,7 +366,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt + val numNodes = scala.math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size @@ -304,12 +374,15 @@ object DecisionTree extends Serializable with Logging { val numBins = bins(0).length logDebug("numBins = " + numBins) + // shift when more than one group is used at deep tree level + val groupShift = numNodes * groupIndex + /** Find the filters used before reaching the current code.
*/ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -878,7 +951,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 350130c914f26..e21db8a3bb8cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -254,7 +254,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -281,7 +281,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -310,7 +310,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -333,7 +333,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -381,7 +381,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) From 
abc5a23bf80d792a345d723b44bff3ee217cd5ac Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Mon, 21 Apr 2014 18:41:36 -0700 Subject: [PATCH 02/18] Parameterizing max memory. --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++++-- .../apache/spark/mllib/tree/configuration/Strategy.scala | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad901d4f67398..ffee3fd848955 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.Utils.memoryStringToMb import org.apache.spark.mllib.linalg.{Vector, Vectors} /** @@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB + val maxMemoryUsage = strategy.maxMemory * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage) val numElementsPerNode = { strategy.algo match { @@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging { val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt val maxBins = options.getOrElse('maxBins, "100").toString.toInt + val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString) - val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage) val model = DecisionTree.train(trainData, strategy) + + // Load test data. 
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8767aca47cd5a..fd7a9ed1514c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -43,4 +43,5 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val maxMemory: Int = 128) extends Serializable From 2f1e093c5187a1ed532f9c19b25f8a2a6a46e27a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 21 Apr 2014 20:49:46 -0700 Subject: [PATCH 03/18] minor: added doc for maxMemory parameter --- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index fd7a9ed1514c9..18918ad5c746e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param maxMemory maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * */ @Experimental class Strategy ( From 02877721328a560f210a7906061108ce5dd4bbbe Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Tue, 22 Apr 2014 11:13:27 -0700 Subject: [PATCH 04/18] Fixing scalastyle issue. --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ffee3fd848955..3dd410e933fa7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -89,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } logDebug("numElementsPerNode = " + numElementsPerNode) - val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array + val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^(level-1). level is zero indexed. 
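Patches 02 through 04 turn the hard-coded budget into a user-facing Strategy parameter. A minimal, REPL-style usage sketch follows (constructor argument order as shown in the diffs above; the training RDD is assumed to exist, and the parameter is later renamed to maxMemoryInMB in PATCH 08/18):

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Raise the histogram-aggregation budget from the 128 MB default to 256 MB,
// letting deeper levels of the tree be trained as a single group.
val strategy = new Strategy(Classification, Gini, maxDepth = 10, maxBins = 100, maxMemory = 256)
// val model = DecisionTree.train(trainingData, strategy) // trainingData: RDD[LabeledPoint]
```

As the runner code above shows, a command-line value such as "128m" can also be converted with memoryStringToMb before being passed in.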
From 719d0098bb08b50e523cec3e388115d5a206512b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 23 Apr 2014 17:04:05 -0700 Subject: [PATCH 05/18] updating user documentation --- docs/mllib-classification-regression.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 2c42f60c2ecce..b06e799577fb4 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -294,12 +294,9 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training parameter 2. No split candidate leads to an information gain at the node. -### Practical Limitations - -The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. The current implementation might not scale to very deep trees since the memory requirement grows exponentially with tree depth. -Please drop us a line if you encounter any issues. We are planning to solve this problem in the near future and real-world examples will be great. +### Implementation Details +The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. Based upon the `maxMemory` parameter set during training (default is 128 MB), the task is broken down into smaller groups to avoid out-of-memory errors during computation. ## Implementation in MLlib From 15171550fe83e42fcb707744c9035ed540fb78d1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 29 Apr 2014 14:45:34 -0700 Subject: [PATCH 06/18] updated documentation --- docs/mllib-decision-tree.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 0693766990732..6667911a6abaf 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -95,15 +95,9 @@ The recursive tree construction is stopped at a node when one of the two conditi ### Practical limitations -1. The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* - in memory for aggregating histograms over partitions. The current implementation might not scale - to very deep trees since the memory requirement grows exponentially with tree depth. -2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. -3. Python is not supported in this release. - -We are planning to solve these problems in the near future. Please drop us a line if you encounter any issues. +1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. +2. Python is not supported in this release.
## Examples From 718506b2a0146a5794261a553847d363b7dfb932 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 16:29:24 -0700 Subject: [PATCH 07/18] added unit test --- .../examples/mllib/DecisionTreeRunner.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 64 ++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0bd847d7bab30..9832bec90d7ee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -51,7 +51,7 @@ object DecisionTreeRunner { algo: Algo = Classification, maxDepth: Int = 5, impurity: ImpurityType = Gini, - maxBins: Int = 20) + maxBins: Int = 100) def main(args: Array[String]) { val defaultParams = Params() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e21db8a3bb8cc..4a0b399ca3dde 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits(0)._2.rightImpurity === 0) assert(bestSplits(0)._2.predict === 1) } + + test("test second level node building with/without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1) + val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1) + val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) + val parentImpurities = Array(0.5, 0.5, 0.5) + + // Single group second level tree construction. + val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + splits, bins, 10) + assert(bestSplits.length === 2) + assert(bestSplits(0)._2.gain > 0) + assert(bestSplits(1)._2.gain > 0) + + // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second + // level tree construction. 
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + filters, splits, bins, 0) + assert(bestSplitsWithGroups.length === 2) + assert(bestSplitsWithGroups(0)._2.gain > 0) + assert(bestSplitsWithGroups(1)._2.gain > 0) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until bestSplits.length) { + assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) + assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) + assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) + assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) + assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) + assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + } + + } + } object DecisionTreeSuite { @@ -412,6 +460,20 @@ object DecisionTreeSuite { arr } + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } else { + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + } + arr + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ From e0426ee74d5e233c1e7b14e29135015d09a0370c Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 17:36:47 -0700 Subject: [PATCH 08/18] renamed parameter --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6f1f3883a7e81..4af6a827946bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,7 +77,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemory * 1024 * 1024 + val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage) val numElementsPerNode = { strategy.algo match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 18918ad5c746e..eeec2f1621cdd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,7 +35,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemory maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. 
* */ @@ -47,4 +47,4 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemory: Int = 128) extends Serializable + val maxMemoryInMB: Int = 128) extends Serializable From dad96523d740c2b7ced0f0d73ade66e528b64064 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 21:59:55 -0700 Subject: [PATCH 09/18] removed unused imports --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4af6a827946bd..952f03f10e538 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,8 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.util.Utils.memoryStringToMb -import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: Experimental :: From cbd9f140fd8d43941c61acd6055636bad88b358d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 3 May 2014 09:16:42 -0700 Subject: [PATCH 10/18] modified scala.math to math --- .../spark/mllib/tree/DecisionTree.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 952f03f10e538..a5a4e61049ccf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -85,11 +85,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) + val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^(level-1). level is zero indexed. 
- val maxLevelForSingleGroup = scala.math.max( - (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0) + val maxLevelForSingleGroup = math.max( + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -120,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) + require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level are leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -153,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val nodeIndex = math.pow(2, level).toInt - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -174,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity @@ -300,7 +300,7 @@ object DecisionTree extends Serializable with Logging { maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { - val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt logDebug("numGroups = " + numGroups) var groupIndex = 0 var bestSplits = new Array[(Split, InformationGainStats)](0) @@ -366,7 +366,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt / numGroups + val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size @@ -382,7 +382,7 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -951,7 +951,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) From 5e822020ce50c6e1bdbdbb3d94d5cbc4c715731e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:34:58 -0700 Subject: [PATCH 11/18] added documentation, fixed off by 1 error in max level calculation --- .../apache/spark/mllib/tree/DecisionTree.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a5a4e61049ccf..6c99f82f687e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -76,10 +76,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 - logDebug("max memory usage for aggregates = " + maxMemoryUsage) + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") val numElementsPerNode = { strategy.algo match { - case Classification => 2 * numBins * numFeatures + case Classification => 2 * numBins * numFeatures case Regression => 3 * numBins * numFeatures } } @@ -87,9 +87,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) - // nodes at a level is 2^(level-1). level is zero indexed. + // nodes at a level is 2^level. level is zero indexed. val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0) + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]], maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation - if (level > maxLevelForSingleGroup) { + if (level > maxLevelForSingleGroup) { + // When information for all nodes at a given level cannot be stored in memory, + // the nodes are divided into multiple groups at each level with the number of groups + // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, + // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. 
val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt logDebug("numGroups = " + numGroups) - var groupIndex = 0 var bestSplits = new Array[(Split, InformationGainStats)](0) + // Iterate over each group of nodes at a level. + var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, numGroups, groupIndex) From 4731cda7b08fdcd365dd1b690ac04a26f6e85657 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:44:39 -0700 Subject: [PATCH 12/18] formatting --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- .../apache/spark/mllib/tree/DecisionTreeSuite.scala | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6c99f82f687e8..4d7ac51e2f01e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -275,7 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level + * Returns an array of optimal splits for all nodes at a given level. Splits the tasks into + * multiple groups if the level-wise training tasks could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4a0b399ca3dde..2155ed7b4a154 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -405,8 +405,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1) - val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1) + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()) ,1) val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) @@ -444,7 +444,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -453,7 +453,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } @@ -462,7 +462,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp @@ -476,7 +476,7 @@ object DecisionTreeSuite { def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new 
Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { From 5eca9e4fbd0e27e335d5cea0bf26b1a436be0457 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:47:14 -0700 Subject: [PATCH 13/18] grammar --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4d7ac51e2f01e..1f99f28e991f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -275,8 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level. Splits the tasks into - * multiple groups if the level-wise training tasks could lead to memory overflow. + * Returns an array of optimal splits for all nodes at a given level. Splits the task into + * multiple groups if the level-wise training task could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree From 8053fed22249bc788ba988489caa22f732b6416d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:48:02 -0700 Subject: [PATCH 14/18] more formatting --- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2155ed7b4a154..51802706d2fc9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -405,9 +405,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()) ,1) - val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) + val filters = Array[List[Filter]](List(),List(leftFilter), List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. From 426bb285f16c816b19e5c25518024ae4d2141c1a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 00:16:02 -0700 Subject: [PATCH 15/18] programming guide blurb --- docs/mllib-decision-tree.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 6667911a6abaf..a2a2999a00e3a 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -93,6 +93,10 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training parameter 2. No split candidate leads to an information gain at the node. +### Max memory requirements + +For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree.
This could lead to high memory requirements at deeper levels of the tree, leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB` training parameter is provided, which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision tree algorithm to work in most scenarios. Once the memory requirement for a level-wise computation crosses the `maxMemoryInMB` threshold, the node training tasks at each subsequent level are split into smaller tasks. + ### Practical limitations 1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for From b27ad2c20edb8a7bf0c0edd5d82a6a683b5d9ea2 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 00:19:10 -0700 Subject: [PATCH 16/18] formatting --- .../apache/spark/mllib/tree/configuration/Strategy.scala | 2 +- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index eeec2f1621cdd..1b505fd76eb75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is - * 128 MB. + * 128 MB. * */ @Experimental diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 51802706d2fc9..bc3b1a3fbe95c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -407,7 +407,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(),List(leftFilter), List(rightFilter)) + val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction.
@@ -463,7 +463,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600){ + if (i < 600) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } else { @@ -477,7 +477,7 @@ object DecisionTreeSuite { def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600){ + if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) From ce004a1ab63405e0a5d0bc892a48b1c96c4d6605 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 10:29:04 -0700 Subject: [PATCH 17/18] minor formatting --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1f99f28e991f7..c3cbe2c63ab03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,12 +77,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = { + val numElementsPerNode = strategy.algo match { case Classification => 2 * numBins * numFeatures case Regression => 3 * numBins * numFeatures } - } + logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) From 7fc95457ec66023ddf14e0ef3e1e18cbf828a4db Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 7 May 2014 10:47:27 -0700 Subject: [PATCH 18/18] added docs --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c3cbe2c63ab03..0fe30a3e7040b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -334,6 +334,8 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param numGroups total number of node groups at the current level. Default value is set to 1. + * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup(
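Taken together, the group-splitting control flow that this series threads through findBestSplits and findBestSplitsPerGroup reduces to the following REPL-style Scala sketch. The per-group body here is a hypothetical stub standing in for the real histogram aggregation, which also carries the RDD, strategy, filters, splits, and bins shown in the diffs.

```scala
// Sketch: level-wise best-split search that falls back to groups past the
// single-group depth limit. perGroup stubs the real aggregation work.
def findBestSplitsSketch(level: Int, maxLevelForSingleGroup: Int): Array[String] = {
  def perGroup(numGroups: Int = 1, groupIndex: Int = 0): Array[String] = {
    val numNodes = math.pow(2, level).toInt / numGroups
    val groupShift = numNodes * groupIndex // offset into the level's node-indexed arrays
    Array.tabulate(numNodes)(n => s"best split for node ${groupShift + n} at level $level")
  }
  if (level > maxLevelForSingleGroup) {
    // The number of groups doubles with each level past the single-group limit.
    val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
    (0 until numGroups).toArray.flatMap(g => perGroup(numGroups, g))
  } else {
    perGroup()
  }
}
```

At level 12 with maxLevelForSingleGroup = 10, for instance, the 4,096 nodes are processed as four groups of 1,024 each, keeping every aggregation pass within the configured memory budget.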