Skip to content

Commit 1a8f0ad

Browse files
committed
Eliminated pre-allocated nodes array in main train() method.
* Nodes are constructed and added to the tree structure as needed during training. Moved tree construction from main train() method into findBestSplitsPerGroup() since there is no need to keep the (split, gain) array for an entire level of nodes. Only one element of that array is needed at a time, so we do not the array. findBestSplits() now returns 2 items: * rootNode (newly created root node on first iteration, same root node on later iterations) * doneTraining (indicating if all nodes at that level were leafs) Also: * Added Node.deepCopy (private[tree]), used for test suite * Updated test suite (same functionality)
1 parent 2ab763b commit 1a8f0ad

File tree

4 files changed

+224
-172
lines changed

4 files changed

+224
-172
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8787
val maxDepth = strategy.maxDepth
8888
require(maxDepth <= 30,
8989
s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
90-
// Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
91-
val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
92-
// dummy value for top node (updated during first split calculation)
93-
val nodes = new Array[Node](maxNumNodesPlus1)
9490

9591
// Calculate level for single group construction
9692

@@ -118,61 +114,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
118114
* beforehand and is not used in later levels.
119115
*/
120116

117+
var topNode: Node = null // set on first iteration
121118
var level = 0
122119
var break = false
123120
while (level <= maxDepth && !break) {
124-
125121
logDebug("#####################################")
126122
logDebug("level = " + level)
127123
logDebug("#####################################")
128124

129125
// Find best split for all nodes at a level.
130126
timer.start("findBestSplits")
131-
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
132-
DecisionTree.findBestSplits(treeInput,
133-
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
127+
val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
128+
metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
134129
timer.stop("findBestSplits")
135130

136-
val levelNodeIndexOffset = Node.startIndexInLevel(level)
137-
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
138-
val nodeIndex = levelNodeIndexOffset + index
139-
140-
// Extract info for this node (index) at the current level.
141-
timer.start("extractNodeInfo")
142-
val split = nodeSplitStats._1
143-
val stats = nodeSplitStats._2
144-
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
145-
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
146-
logDebug("Node = " + node)
147-
nodes(nodeIndex) = node
148-
timer.stop("extractNodeInfo")
149-
150-
if (level != 0) {
151-
// Set parent.
152-
val parentNodeIndex = Node.parentIndex(nodeIndex)
153-
if (Node.isLeftChild(nodeIndex)) {
154-
nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
155-
} else {
156-
nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
157-
}
158-
}
159-
if (level < maxDepth) {
160-
logDebug("leftChildIndex = " + Node.leftChildIndex(nodeIndex) +
161-
", impurity = " + stats.leftImpurity)
162-
logDebug("rightChildIndex = " + Node.rightChildIndex(nodeIndex) +
163-
", impurity = " + stats.rightImpurity)
164-
}
165-
logDebug("final best split = " + split)
131+
if (level == 0) {
132+
topNode = tmpTopNode
166133
}
167-
require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
168-
// Check whether all the nodes at the current level at leaves.
169-
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
170-
logDebug("all leaf = " + allLeaf)
171-
if (allLeaf) {
172-
break = true // no more tree construction
173-
} else {
174-
level += 1
134+
if (doneTraining) {
135+
break = true
136+
logDebug("done training")
175137
}
138+
139+
level += 1
176140
}
177141

178142
logDebug("#####################################")
@@ -184,7 +148,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
184148
logInfo("Internal timing for DecisionTree:")
185149
logInfo(s"$timer")
186150

187-
new DecisionTreeModel(nodes(1), strategy.algo)
151+
new DecisionTreeModel(topNode, strategy.algo)
188152
}
189153

190154
}
@@ -398,17 +362,19 @@ object DecisionTree extends Serializable with Logging {
398362
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
399363
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
400364
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
401-
* @return array (over nodes) of splits with best split for each node at a given level.
365+
* @return (root, doneTraining) where:
366+
* root = Root node (which is newly created on the first iteration),
367+
* doneTraining = true if no more internal nodes were created.
402368
*/
403369
private[tree] def findBestSplits(
404370
input: RDD[TreePoint],
405371
metadata: DecisionTreeMetadata,
406372
level: Int,
407-
nodes: Array[Node],
373+
topNode: Node,
408374
splits: Array[Array[Split]],
409375
bins: Array[Array[Bin]],
410376
maxLevelForSingleGroup: Int,
411-
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
377+
timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
412378
// split into groups to avoid memory overflow during aggregation
413379
if (level > maxLevelForSingleGroup) {
414380
// When information for all nodes at a given level cannot be stored in memory,
@@ -417,18 +383,18 @@ object DecisionTree extends Serializable with Logging {
417383
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
418384
val numGroups = 1 << level - maxLevelForSingleGroup
419385
logDebug("numGroups = " + numGroups)
420-
var bestSplits = new Array[(Split, InformationGainStats)](0)
421386
// Iterate over each group of nodes at a level.
422387
var groupIndex = 0
388+
var doneTraining = true
423389
while (groupIndex < numGroups) {
424-
val bestSplitsForGroup = findBestSplitsPerGroup(input, metadata, level,
425-
nodes, splits, bins, timer, numGroups, groupIndex)
426-
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
390+
val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
391+
topNode, splits, bins, timer, numGroups, groupIndex)
392+
doneTraining = doneTraining && doneTrainingGroup
427393
groupIndex += 1
428394
}
429-
bestSplits
395+
(topNode, doneTraining) // Not first iteration, so topNode was already set.
430396
} else {
431-
findBestSplitsPerGroup(input, metadata, level, nodes, splits, bins, timer)
397+
findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
432398
}
433399
}
434400

@@ -570,23 +536,25 @@ object DecisionTree extends Serializable with Logging {
570536
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
571537
* @param metadata Learning and dataset metadata
572538
* @param level Level of the tree
573-
* @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
539+
* @param topNode Root node of the tree (or invalid node when training first level).
574540
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
575541
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
576542
* @param numGroups total number of node groups at the current level. Default value is set to 1.
577543
* @param groupIndex index of the node group being processed. Default value is set to 0.
578-
* @return array of splits with best splits for all nodes at a given level.
544+
* @return (root, doneTraining) where:
545+
* root = Root node (which is newly created on the first iteration),
546+
* doneTraining = true if no more internal nodes were created.
579547
*/
580548
private def findBestSplitsPerGroup(
581549
input: RDD[TreePoint],
582550
metadata: DecisionTreeMetadata,
583551
level: Int,
584-
nodes: Array[Node],
552+
topNode: Node,
585553
splits: Array[Array[Split]],
586554
bins: Array[Array[Bin]],
587555
timer: TimeTracker,
588556
numGroups: Int = 1,
589-
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
557+
groupIndex: Int = 0): (Node, Boolean) = {
590558

591559
/*
592560
* The high-level descriptions of the best split optimizations are noted here.
@@ -643,7 +611,7 @@ object DecisionTree extends Serializable with Logging {
643611
0
644612
} else {
645613
val globalNodeIndex =
646-
predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
614+
predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
647615
globalNodeIndex - globalNodeIndexOffset
648616
}
649617
}
@@ -686,18 +654,53 @@ object DecisionTree extends Serializable with Logging {
686654

687655
// Calculate best splits for all nodes at a given level
688656
timer.start("chooseSplits")
689-
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
657+
// On the first iteration, we need to get and return the newly created root node.
658+
var newTopNode: Node = topNode
690659
// Iterating over all nodes at this level
691660
var nodeIndex = 0
661+
var internalNodeCount = 0
692662
while (nodeIndex < numNodes) {
693-
bestSplits(nodeIndex) =
663+
val (split: Split, stats: InformationGainStats) =
694664
binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
695-
logDebug("best split = " + bestSplits(nodeIndex)._1)
665+
logDebug("best split = " + split)
666+
667+
val globalNodeIndex = globalNodeIndexOffset + nodeIndex
668+
669+
// Extract info for this node at the current level.
670+
timer.start("extractNodeInfo")
671+
val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
672+
val node =
673+
new Node(globalNodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
674+
logDebug("Node = " + node)
675+
timer.stop("extractNodeInfo")
676+
677+
if (!isLeaf) {
678+
internalNodeCount += 1
679+
}
680+
if (level == 0) {
681+
newTopNode = node
682+
} else {
683+
// Set parent.
684+
val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
685+
if (Node.isLeftChild(globalNodeIndex)) {
686+
parentNode.leftNode = Some(node)
687+
} else {
688+
parentNode.rightNode = Some(node)
689+
}
690+
}
691+
if (level < metadata.maxDepth) {
692+
logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
693+
", impurity = " + stats.leftImpurity)
694+
logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
695+
", impurity = " + stats.rightImpurity)
696+
}
697+
696698
nodeIndex += 1
697699
}
698700
timer.stop("chooseSplits")
699701

700-
bestSplits
702+
val doneTraining = internalNodeCount == 0
703+
(newTopNode, doneTraining)
701704
}
702705

703706
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ private[tree] class DecisionTreeMetadata(
4545
val unorderedFeatures: Set[Int],
4646
val numBins: Array[Int],
4747
val impurity: Impurity,
48-
val quantileStrategy: QuantileStrategy) extends Serializable {
48+
val quantileStrategy: QuantileStrategy,
49+
val maxDepth: Int) extends Serializable {
4950

5051
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
5152

@@ -127,7 +128,7 @@ private[tree] object DecisionTreeMetadata {
127128

128129
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
129130
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
130-
strategy.impurity, strategy.quantileCalculationStrategy)
131+
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth)
131132
}
132133

133134
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,23 @@ class Node (
9393
}
9494
}
9595

96+
/**
97+
* Returns a deep copy of the subtree rooted at this node.
98+
*/
99+
private[tree] def deepCopy(): Node = {
100+
val leftNodeCopy = if (leftNode.isEmpty) {
101+
None
102+
} else {
103+
Some(leftNode.get.deepCopy())
104+
}
105+
val rightNodeCopy = if (rightNode.isEmpty) {
106+
None
107+
} else {
108+
Some(rightNode.get.deepCopy())
109+
}
110+
new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
111+
}
112+
96113
/**
97114
* Get the number of nodes in tree below this node, including leaf nodes.
98115
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
@@ -190,4 +207,22 @@ private[tree] object Node {
190207
*/
191208
def startIndexInLevel(level: Int): Int = 1 << level
192209

210+
/**
211+
* Traces down from a root node to get the node with the given node index.
212+
* This assumes the node exists.
213+
*/
214+
def getNode(nodeIndex: Int, rootNode: Node): Node = {
215+
var tmpNode: Node = rootNode
216+
var levelsToGo = indexToLevel(nodeIndex)
217+
while (levelsToGo > 0) {
218+
if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
219+
tmpNode = tmpNode.leftNode.get
220+
} else {
221+
tmpNode = tmpNode.rightNode.get
222+
}
223+
levelsToGo -= 1
224+
}
225+
tmpNode
226+
}
227+
193228
}

0 commit comments

Comments
 (0)