@@ -87,10 +87,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8787 val maxDepth = strategy.maxDepth
8888 require(maxDepth <= 30 ,
8989 s " DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth. " )
90- // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
91- val maxNumNodesPlus1 = Node .startIndexInLevel(maxDepth + 1 )
92- // dummy value for top node (updated during first split calculation)
93- val nodes = new Array [Node ](maxNumNodesPlus1)
9490
9591 // Calculate level for single group construction
9692
@@ -118,61 +114,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
118114 * beforehand and is not used in later levels.
119115 */
120116
117+ var topNode : Node = null // set on first iteration
121118 var level = 0
122119 var break = false
123120 while (level <= maxDepth && ! break) {
124-
125121 logDebug(" #####################################" )
126122 logDebug(" level = " + level)
127123 logDebug(" #####################################" )
128124
129125 // Find best split for all nodes at a level.
130126 timer.start(" findBestSplits" )
131- val splitsStatsForLevel : Array [(Split , InformationGainStats )] =
132- DecisionTree .findBestSplits(treeInput,
133- metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
127+ val (tmpTopNode : Node , doneTraining : Boolean ) = DecisionTree .findBestSplits(treeInput,
128+ metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
134129 timer.stop(" findBestSplits" )
135130
136- val levelNodeIndexOffset = Node .startIndexInLevel(level)
137- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
138- val nodeIndex = levelNodeIndexOffset + index
139-
140- // Extract info for this node (index) at the current level.
141- timer.start(" extractNodeInfo" )
142- val split = nodeSplitStats._1
143- val stats = nodeSplitStats._2
144- val isLeaf = (stats.gain <= 0 ) || (level == strategy.maxDepth)
145- val node = new Node (nodeIndex, stats.predict, isLeaf, Some (split), None , None , Some (stats))
146- logDebug(" Node = " + node)
147- nodes(nodeIndex) = node
148- timer.stop(" extractNodeInfo" )
149-
150- if (level != 0 ) {
151- // Set parent.
152- val parentNodeIndex = Node .parentIndex(nodeIndex)
153- if (Node .isLeftChild(nodeIndex)) {
154- nodes(parentNodeIndex).leftNode = Some (nodes(nodeIndex))
155- } else {
156- nodes(parentNodeIndex).rightNode = Some (nodes(nodeIndex))
157- }
158- }
159- if (level < maxDepth) {
160- logDebug(" leftChildIndex = " + Node .leftChildIndex(nodeIndex) +
161- " , impurity = " + stats.leftImpurity)
162- logDebug(" rightChildIndex = " + Node .rightChildIndex(nodeIndex) +
163- " , impurity = " + stats.rightImpurity)
164- }
165- logDebug(" final best split = " + split)
131+ if (level == 0 ) {
132+ topNode = tmpTopNode
166133 }
167- require(Node .maxNodesInLevel(level) == splitsStatsForLevel.length)
168- // Check whether all the nodes at the current level at leaves.
169- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
170- logDebug(" all leaf = " + allLeaf)
171- if (allLeaf) {
172- break = true // no more tree construction
173- } else {
174- level += 1
134+ if (doneTraining) {
135+ break = true
136+ logDebug(" done training" )
175137 }
138+
139+ level += 1
176140 }
177141
178142 logDebug(" #####################################" )
@@ -184,7 +148,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
184148 logInfo(" Internal timing for DecisionTree:" )
185149 logInfo(s " $timer" )
186150
187- new DecisionTreeModel (nodes( 1 ) , strategy.algo)
151+ new DecisionTreeModel (topNode , strategy.algo)
188152 }
189153
190154}
@@ -398,17 +362,19 @@ object DecisionTree extends Serializable with Logging {
398362 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
399363 * @param bins possible bins for all features, indexed (numFeatures)(numBins)
400364 * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
401- * @return array (over nodes) of splits with best split for each node at a given level.
365+ * @return (root, doneTraining) where:
366+ * root = Root node (which is newly created on the first iteration),
367+ * doneTraining = true if no more internal nodes were created.
402368 */
403369 private [tree] def findBestSplits (
404370 input : RDD [TreePoint ],
405371 metadata : DecisionTreeMetadata ,
406372 level : Int ,
407- nodes : Array [ Node ] ,
373+ topNode : Node ,
408374 splits : Array [Array [Split ]],
409375 bins : Array [Array [Bin ]],
410376 maxLevelForSingleGroup : Int ,
411- timer : TimeTracker = new TimeTracker ): Array [( Split , InformationGainStats )] = {
377+ timer : TimeTracker = new TimeTracker ): ( Node , Boolean ) = {
412378 // split into groups to avoid memory overflow during aggregation
413379 if (level > maxLevelForSingleGroup) {
414380 // When information for all nodes at a given level cannot be stored in memory,
@@ -417,18 +383,18 @@ object DecisionTree extends Serializable with Logging {
417383 // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
418384 val numGroups = 1 << level - maxLevelForSingleGroup
419385 logDebug(" numGroups = " + numGroups)
420- var bestSplits = new Array [(Split , InformationGainStats )](0 )
421386 // Iterate over each group of nodes at a level.
422387 var groupIndex = 0
388+ var doneTraining = true
423389 while (groupIndex < numGroups) {
424- val bestSplitsForGroup = findBestSplitsPerGroup(input, metadata, level,
425- nodes , splits, bins, timer, numGroups, groupIndex)
426- bestSplits = Array .concat(bestSplits, bestSplitsForGroup)
390+ val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
391+ topNode , splits, bins, timer, numGroups, groupIndex)
392+ doneTraining = doneTraining && doneTrainingGroup
427393 groupIndex += 1
428394 }
429- bestSplits
395+ (topNode, doneTraining) // Not first iteration, so topNode was already set.
430396 } else {
431- findBestSplitsPerGroup(input, metadata, level, nodes , splits, bins, timer)
397+ findBestSplitsPerGroup(input, metadata, level, topNode , splits, bins, timer)
432398 }
433399 }
434400
@@ -570,23 +536,25 @@ object DecisionTree extends Serializable with Logging {
570536 * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint ]]
571537 * @param metadata Learning and dataset metadata
572538 * @param level Level of the tree
573- * @param nodes Array of all nodes in the tree. Used for matching data points to nodes .
539+ * @param topNode Root node of the tree (or invalid node when training first level) .
574540 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
575541 * @param bins possible bins for all features, indexed (numFeatures)(numBins)
576542 * @param numGroups total number of node groups at the current level. Default value is set to 1.
577543 * @param groupIndex index of the node group being processed. Default value is set to 0.
578- * @return array of splits with best splits for all nodes at a given level.
544+ * @return (root, doneTraining) where:
545+ * root = Root node (which is newly created on the first iteration),
546+ * doneTraining = true if no more internal nodes were created.
579547 */
580548 private def findBestSplitsPerGroup (
581549 input : RDD [TreePoint ],
582550 metadata : DecisionTreeMetadata ,
583551 level : Int ,
584- nodes : Array [ Node ] ,
552+ topNode : Node ,
585553 splits : Array [Array [Split ]],
586554 bins : Array [Array [Bin ]],
587555 timer : TimeTracker ,
588556 numGroups : Int = 1 ,
589- groupIndex : Int = 0 ): Array [( Split , InformationGainStats )] = {
557+ groupIndex : Int = 0 ): ( Node , Boolean ) = {
590558
591559 /*
592560 * The high-level descriptions of the best split optimizations are noted here.
@@ -643,7 +611,7 @@ object DecisionTree extends Serializable with Logging {
643611 0
644612 } else {
645613 val globalNodeIndex =
646- predictNodeIndex(nodes( 1 ) , treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
614+ predictNodeIndex(topNode , treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
647615 globalNodeIndex - globalNodeIndexOffset
648616 }
649617 }
@@ -686,18 +654,53 @@ object DecisionTree extends Serializable with Logging {
686654
687655 // Calculate best splits for all nodes at a given level
688656 timer.start(" chooseSplits" )
689- val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
657+ // On the first iteration, we need to get and return the newly created root node.
658+ var newTopNode : Node = topNode
690659 // Iterating over all nodes at this level
691660 var nodeIndex = 0
661+ var internalNodeCount = 0
692662 while (nodeIndex < numNodes) {
693- bestSplits(nodeIndex ) =
663+ val ( split : Split , stats : InformationGainStats ) =
694664 binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
695- logDebug(" best split = " + bestSplits(nodeIndex)._1)
665+ logDebug(" best split = " + split)
666+
667+ val globalNodeIndex = globalNodeIndexOffset + nodeIndex
668+
669+ // Extract info for this node at the current level.
670+ timer.start(" extractNodeInfo" )
671+ val isLeaf = (stats.gain <= 0 ) || (level == metadata.maxDepth)
672+ val node =
673+ new Node (globalNodeIndex, stats.predict, isLeaf, Some (split), None , None , Some (stats))
674+ logDebug(" Node = " + node)
675+ timer.stop(" extractNodeInfo" )
676+
677+ if (! isLeaf) {
678+ internalNodeCount += 1
679+ }
680+ if (level == 0 ) {
681+ newTopNode = node
682+ } else {
683+ // Set parent.
684+ val parentNode = Node .getNode(Node .parentIndex(globalNodeIndex), topNode)
685+ if (Node .isLeftChild(globalNodeIndex)) {
686+ parentNode.leftNode = Some (node)
687+ } else {
688+ parentNode.rightNode = Some (node)
689+ }
690+ }
691+ if (level < metadata.maxDepth) {
692+ logDebug(" leftChildIndex = " + Node .leftChildIndex(globalNodeIndex) +
693+ " , impurity = " + stats.leftImpurity)
694+ logDebug(" rightChildIndex = " + Node .rightChildIndex(globalNodeIndex) +
695+ " , impurity = " + stats.rightImpurity)
696+ }
697+
696698 nodeIndex += 1
697699 }
698700 timer.stop(" chooseSplits" )
699701
700- bestSplits
702+ val doneTraining = internalNodeCount == 0
703+ (newTopNode, doneTraining)
701704 }
702705
703706 /**
0 commit comments