Skip to content

Commit f79f77c

Browse files
committed
renamed NodeQueue to NodeStack
1 parent 3c00d03 commit f79f77c

File tree

2 files changed

+28
-28
lines changed

2 files changed

+28
-28
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,22 @@ private[spark] object RandomForest extends Logging {
170170
training the same tree in the next iteration. This focus allows us to send fewer trees to
171171
workers on each iteration; see topNodesForGroup below.
172172
*/
173-
val nodeQueue = new NodeQueue
173+
val nodeStack = new NodeStack
174174

175175
val rng = new Random()
176176
rng.setSeed(seed)
177177

178178
// Allocate and queue root nodes.
179179
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
180-
Range(0, numTrees).foreach(treeIndex => nodeQueue.put(treeIndex, topNodes(treeIndex)))
180+
Range(0, numTrees).foreach(treeIndex => nodeStack.put(treeIndex, topNodes(treeIndex)))
181181

182182
timer.stop("init")
183183

184-
while (nodeQueue.nonEmpty) {
184+
while (nodeStack.nonEmpty) {
185185
// Collect some nodes to split, and choose features for each node (if subsampling).
186186
// Each group of nodes may come from one or multiple trees, and at multiple levels.
187187
val (nodesForGroup, treeToNodeToIndexInfo) =
188-
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
188+
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
189189
// Sanity check (should never occur):
190190
assert(nodesForGroup.nonEmpty,
191191
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
@@ -197,7 +197,7 @@ private[spark] object RandomForest extends Logging {
197197
// Choose node splits, and enqueue new nodes as needed.
198198
timer.start("findBestSplits")
199199
RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
200-
treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
200+
treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
201201
timer.stop("findBestSplits")
202202
}
203203

@@ -353,7 +353,7 @@ private[spark] object RandomForest extends Logging {
353353
* where nodeIndexInfo stores the index in the group and the
354354
* feature subsets (if using feature subsets).
355355
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
356-
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
356+
* @param nodeStack Stack of nodes to split, with values (treeIndex, node).
357357
* Updated with new non-leaf nodes which are created.
358358
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
359359
* each value in the array is the data point's node Id
@@ -368,7 +368,7 @@ private[spark] object RandomForest extends Logging {
368368
nodesForGroup: Map[Int, Array[LearningNode]],
369369
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
370370
splits: Array[Array[Split]],
371-
nodeQueue: NodeQueue,
371+
nodeStack: NodeStack,
372372
timer: TimeTracker = new TimeTracker,
373373
nodeIdCache: Option[NodeIdCache] = None): Unit = {
374374

@@ -607,10 +607,10 @@ private[spark] object RandomForest extends Logging {
607607

608608
// enqueue left child and right child if they are not leaves
609609
if (!leftChildIsLeaf) {
610-
nodeQueue.put(treeIndex, node.leftChild.get)
610+
nodeStack.put(treeIndex, node.leftChild.get)
611611
}
612612
if (!rightChildIsLeaf) {
613-
nodeQueue.put(treeIndex, node.rightChild.get)
613+
nodeStack.put(treeIndex, node.rightChild.get)
614614
}
615615

616616
logDebug("leftChildIndex = " + node.leftChild.get.id +
@@ -1043,7 +1043,7 @@ private[spark] object RandomForest extends Logging {
10431043
* will be needed; this allows an adaptive number of nodes since different nodes may require
10441044
* different amounts of memory (if featureSubsetStrategy is not "all").
10451045
*
1046-
* @param nodeQueue Queue of nodes to split.
1046+
* @param nodeStack Stack of nodes to split.
10471047
* @param maxMemoryUsage Bound on size of aggregate statistics.
10481048
* @return (nodesForGroup, treeToNodeToIndexInfo).
10491049
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
@@ -1055,7 +1055,7 @@ private[spark] object RandomForest extends Logging {
10551055
* The feature indices are None if not subsampling features.
10561056
*/
10571057
private[tree] def selectNodesToSplit(
1058-
nodeQueue: NodeQueue,
1058+
nodeStack: NodeStack,
10591059
maxMemoryUsage: Long,
10601060
metadata: DecisionTreeMetadata,
10611061
rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
@@ -1068,8 +1068,8 @@ private[spark] object RandomForest extends Logging {
10681068
var numNodesInGroup = 0
10691069
// If maxMemoryInMB is set very small, we want to still try to split 1 node,
10701070
// so we allow one iteration if memUsage == 0.
1071-
while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
1072-
val (treeIndex, node) = nodeQueue.peek()
1071+
while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
1072+
val (treeIndex, node) = nodeStack.peek()
10731073
// Choose subset of features for node (if subsampling).
10741074
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
10751075
Some(SamplingUtils.reservoirSampleAndCount(Range(0,
@@ -1080,7 +1080,7 @@ private[spark] object RandomForest extends Logging {
10801080
// Check if enough memory remains to add this node to the group.
10811081
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
10821082
if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
1083-
nodeQueue.pop()
1083+
nodeStack.pop()
10841084
mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
10851085
node
10861086
mutableTreeToNodeToIndexInfo
@@ -1126,9 +1126,9 @@ private[spark] object RandomForest extends Logging {
11261126

11271127
/**
11281128
* Class for queueing nodes to split on each iteration.
1129-
* This is a FILO queue.
1129+
* This must be a stack (FILO); see developer note where it is used above.
11301130
*/
1131-
private[impl] class NodeQueue {
1131+
private[impl] class NodeStack {
11321132
private var q: List[(Int, LearningNode)] =
11331133
List.empty[(Int, LearningNode)]
11341134

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel
2222
import org.apache.spark.ml.feature.LabeledPoint
2323
import org.apache.spark.ml.linalg.{Vector, Vectors}
2424
import org.apache.spark.ml.tree._
25-
import org.apache.spark.ml.tree.impl.RandomForest.NodeQueue
25+
import org.apache.spark.ml.tree.impl.RandomForest.NodeStack
2626
import org.apache.spark.ml.util.TestingUtils._
2727
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
2828
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
@@ -239,12 +239,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
239239
val treeToNodeToIndexInfo = Map((0, Map(
240240
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
241241
)))
242-
val nodeQueue = new NodeQueue
242+
val nodeStack = new NodeStack
243243
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
244-
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
244+
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
245245

246246
// don't enqueue leaf nodes into node queue
247-
assert(nodeQueue.isEmpty)
247+
assert(nodeStack.isEmpty)
248248

249249
// set impurity and predict for topNode
250250
assert(topNode.stats !== null)
@@ -281,12 +281,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
281281
val treeToNodeToIndexInfo = Map((0, Map(
282282
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
283283
)))
284-
val nodeQueue = new NodeQueue
284+
val nodeStack = new NodeStack
285285
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
286-
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
286+
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
287287

288288
// don't enqueue a node into node queue if its impurity is 0.0
289-
assert(nodeQueue.isEmpty)
289+
assert(nodeStack.isEmpty)
290290

291291
// set impurity and predict for topNode
292292
assert(topNode.stats !== null)
@@ -393,16 +393,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
393393
val failString = s"Failed on test with:" +
394394
s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
395395
s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
396-
val nodeQueue = new NodeQueue
396+
val nodeStack = new NodeStack
397397
val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
398398
Range(0, numTrees).foreach { treeIndex =>
399399
topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
400-
nodeQueue.put(treeIndex, topNodes(treeIndex))
400+
nodeStack.put(treeIndex, topNodes(treeIndex))
401401
}
402402
val rng = new scala.util.Random(seed = seed)
403403
val (nodesForGroup: Map[Int, Array[LearningNode]],
404404
treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
405-
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
405+
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
406406

407407
assert(nodesForGroup.size === numTrees, failString)
408408
assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
@@ -547,8 +547,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
547547
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
548548
}
549549

550-
test("NodeQueue should be FILO") {
551-
val q = new NodeQueue
550+
test("NodeStack should be FILO") {
551+
val q = new NodeStack
552552
Range(0, 5).foreach { idx =>
553553
val node = LearningNode.emptyNode(idx)
554554
q.put(treeIndex = idx, node = node)

0 commit comments

Comments
 (0)