From fc055532ff3afac0df14e5ff8b63358f9410eae6 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 2 Aug 2015 12:25:30 -0400 Subject: [PATCH 1/7] Initial draft --- .../spark/ml/tree/impl/RandomForest.scala | 38 +++++++++++-------- .../spark/mllib/tree/DecisionTree.scala | 4 +- .../spark/mllib/tree/RandomForest.scala | 1 + 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba..e30c8ce56de8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -27,7 +27,8 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.ml.util.Instrumentation +import org.apache.spark.ml.util.{Instrumentation, Stopwatch, LocalStopwatch, MultiStopwatch} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats @@ -94,11 +95,13 @@ private[spark] object RandomForest extends Logging { instr: Option[Instrumentation[_]], parentUID: Option[String] = None): Array[DecisionTreeModel] = { - val timer = new TimeTracker() + val multiTimer = new MultiStopwatch(input.sparkContext) - timer.start("total") + multiTimer.addLocal("total") + multiTimer("total").start() - timer.start("init") + multiTimer.addLocal("init") + multiTimer("init").start() val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = @@ -114,9 +117,11 @@ private[spark] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - timer.start("findSplits") + multiTimer.addLocal("findSplitsBins") + multiTimer("findSplitsBins").start() val splits = findSplits(retaggedInput, metadata, seed) - timer.stop("findSplits") + multiTimer("findSplitsBins").stop() + logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => s"\t$featureIndex\t${metadata.numBins(featureIndex)}" @@ -142,6 +147,7 @@ private[spark] object RandomForest extends Logging { val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + multiTimer("init").stop() /* * The main idea here is to perform group-wise training of the decision tree nodes thus * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). @@ -170,8 +176,8 @@ private[spark] object RandomForest extends Logging { // Allocate and queue root nodes. val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) - - timer.stop("init") + multiTimer.addLocal("findBestSplits") + multiTimer.addLocal("chooseSplits") while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). @@ -183,18 +189,18 @@ private[spark] object RandomForest extends Logging { s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Choose node splits, and enqueue new nodes as needed. - timer.start("findBestSplits") + multiTimer("findBestSplits").start() RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) - timer.stop("findBestSplits") + treeToNodeToIndexInfo, splits, nodeQueue, multiTimer("chooseSplits"), nodeIdCache) + multiTimer("findBestSplits").stop() } baggedInput.unpersist() - timer.stop("total") + multiTimer("total").stop() logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") + logInfo(s"$multiTimer") // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { @@ -356,7 +362,7 @@ private[spark] object RandomForest extends Logging { treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], nodeQueue: mutable.Queue[(Int, LearningNode)], - timer: TimeTracker = new TimeTracker, + timer: Stopwatch = new LocalStopwatch("chooseSplits"), nodeIdCache: Option[NodeIdCache] = None): Unit = { /* @@ -488,7 +494,7 @@ private[spark] object RandomForest extends Logging { } // Calculate best splits for all nodes in the group - timer.start("chooseSplits") + timer.start() // In each partition, iterate all instances and compute aggregate stats for each node, // yield a (nodeIndex, nodeAggregateStats) pair for each node. @@ -549,7 +555,7 @@ private[spark] object RandomForest extends Logging { (nodeIndex, (split, stats)) }.collectAsMap() - timer.stop("chooseSplits") + timer.stop() val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36feab7859b43..2cd9b1be717f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,7 +19,9 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Since +import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch} +import org.apache.spark.Logging +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 14f11ce51b878..54f8e13e138b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -25,6 +25,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} +import org.apache.spark.ml.util.MultiStopwatch import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ From c981ad554fa4706fbda40b42acbe4b275a2dbf47 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 19 Jul 2016 14:50:12 -0700 Subject: [PATCH 2/7] Remove unused import --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 +--- .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2cd9b1be717f1..36feab7859b43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,9 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch} -import org.apache.spark.Logging -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 54f8e13e138b7..14f11ce51b878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -25,7 +25,6 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} -import org.apache.spark.ml.util.MultiStopwatch import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ From ea9caf497f392b9149572a2ff4fcefed9d66f9ab Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 19 Jul 2016 15:32:19 -0700 Subject: [PATCH 3/7] Add MultiStopWatch to GBT's --- .../ml/tree/impl/GradientBoostedTrees.scala | 25 +++++++++++-------- .../spark/ml/tree/impl/RandomForest.scala | 3 +-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d9..8e5b80d61d323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MultiStopwatch import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} @@ -246,9 +247,11 @@ private[spark] object GradientBoostedTrees extends Logging { boostingStrategy: OldBoostingStrategy, validate: Boolean, seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { - val timer = new TimeTracker() - timer.start("total") - timer.start("init") + val multiTimer = new MultiStopwatch(input.sparkContext) + multiTimer.addLocal("total") + multiTimer.addLocal("init") + multiTimer("total").start() + multiTimer("init").start() boostingStrategy.assertValid() @@ -279,14 +282,15 @@ private[spark] object GradientBoostedTrees extends Logging { val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( treeStrategy.getCheckpointInterval, input.sparkContext) - timer.stop("init") + multiTimer("init").stop() logDebug("##########") logDebug("Building tree 0") logDebug("##########") // Initialize tree - timer.start("building tree 0") + multiTimer.addLocal("building tree 0") + multiTimer("building tree 0").start() val firstTree = new DecisionTreeRegressor().setSeed(seed) val firstTreeModel = firstTree.train(input, treeStrategy) val firstTreeWeight = 1.0 @@ -299,7 +303,7 @@ private[spark] object GradientBoostedTrees extends Logging { logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction - timer.stop("building tree 0") + multiTimer("building tree 0").stop() var validatePredError: RDD[(Double, Double)] = computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) @@ -315,13 +319,14 @@ private[spark] object GradientBoostedTrees extends Logging { LabeledPoint(-loss.gradient(pred, point.label), point.features) } - timer.start(s"building tree $m") + multiTimer.addLocal(s"building tree $m") + multiTimer(s"building tree $m").start() logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val dt = new DecisionTreeRegressor().setSeed(seed + m) val model = dt.train(data, treeStrategy) - timer.stop(s"building tree $m") + multiTimer(s"building tree $m").stop() // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. @@ -355,10 +360,10 @@ private[spark] object GradientBoostedTrees extends Logging { m += 1 } - timer.stop("total") + multiTimer("total").stop() logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") + logInfo(s"$multiTimer") predErrorCheckpointer.deleteAllCheckpoints() validatePredErrorCheckpointer.deleteAllCheckpoints() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index e30c8ce56de8a..7e84b2215d388 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -27,8 +27,7 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.ml.util.{Instrumentation, Stopwatch, LocalStopwatch, MultiStopwatch} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.util.{Instrumentation, LocalStopwatch, MultiStopwatch, Stopwatch} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats From 7cb2fa09232f8512b018e0673d9b2d4402f88c86 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 19 Jul 2016 15:33:37 -0700 Subject: [PATCH 4/7] Remove TimeTracker --- .../spark/ml/tree/impl/TimeTracker.scala | 70 ------------------- 1 file changed, 70 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala deleted file mode 100644 index 4cc250aa462e3..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import scala.collection.mutable.{HashMap => MutableHashMap} - -/** - * Time tracker implementation which holds labeled timers. - */ -private[spark] class TimeTracker extends Serializable { - - private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - /** - * Starts a new timer, or re-starts a stopped timer. - */ - def start(timerLabel: String): Unit = { - val currentTime = System.nanoTime() - if (starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + - s" timerLabel = $timerLabel before that timer was stopped.") - } - starts(timerLabel) = currentTime - } - - /** - * Stops a timer and returns the elapsed time in seconds. - */ - def stop(timerLabel: String): Double = { - val currentTime = System.nanoTime() - if (!starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + - s" timerLabel = $timerLabel, but that timer was not started.") - } - val elapsed = currentTime - starts(timerLabel) - starts.remove(timerLabel) - if (totals.contains(timerLabel)) { - totals(timerLabel) += elapsed - } else { - totals(timerLabel) = elapsed - } - elapsed / 1e9 - } - - /** - * Print all timing results in seconds. - */ - override def toString: String = { - totals.map { case (label, elapsed) => - s" $label: ${elapsed / 1e9}" - }.mkString("\n") - } -} From 3dd9b3135722aa937b04052501876dc2b3ebb06f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 19 Jul 2016 16:21:50 -0700 Subject: [PATCH 5/7] Pass MultiStopWatch instead of LocalStopWatch --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 7e84b2215d388..8dd6781f1ba0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -190,7 +190,7 @@ private[spark] object RandomForest extends Logging { // Choose node splits, and enqueue new nodes as needed. multiTimer("findBestSplits").start() RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, multiTimer("chooseSplits"), nodeIdCache) + treeToNodeToIndexInfo, splits, nodeQueue, multiTimer, nodeIdCache) multiTimer("findBestSplits").stop() } @@ -361,7 +361,7 @@ private[spark] object RandomForest extends Logging { treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], nodeQueue: mutable.Queue[(Int, LearningNode)], - timer: Stopwatch = new LocalStopwatch("chooseSplits"), + multiTimer: MultiStopwatch, nodeIdCache: Option[NodeIdCache] = None): Unit = { /* @@ -493,7 +493,7 @@ private[spark] object RandomForest extends Logging { } // Calculate best splits for all nodes in the group - timer.start() + multiTimer("chooseSplits").start() // In each partition, iterate all instances and compute aggregate stats for each node, // yield a (nodeIndex, nodeAggregateStats) pair for each node. @@ -554,7 +554,7 @@ private[spark] object RandomForest extends Logging { (nodeIndex, (split, stats)) }.collectAsMap() - timer.stop() + multiTimer("chooseSplits").stop() val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( From e5b077de8a901bae666ff25d2e1800caf622681b Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 19 Jul 2016 16:48:51 -0700 Subject: [PATCH 6/7] add distributed timer to multitimer --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 8dd6781f1ba0d..d479ac038dff7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -177,6 +177,7 @@ private[spark] object RandomForest extends Logging { Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) multiTimer.addLocal("findBestSplits") multiTimer.addLocal("chooseSplits") + multiTimer.addDistributed("binsToBestSplit") while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). @@ -549,8 +550,10 @@ private[spark] object RandomForest extends Logging { } // find best split for each node + multiTimer("binsToBestSplit").start() val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + multiTimer("binsToBestSplit").stop() (nodeIndex, (split, stats)) }.collectAsMap() From abff51bd9ec9e5964d1342c8874bab58f152766c Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 20 Jul 2016 11:54:55 -0700 Subject: [PATCH 7/7] Make MultiStopWatch optional --- .../spark/ml/tree/impl/RandomForest.scala | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index d479ac038dff7..03892e129cf5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -94,13 +94,13 @@ private[spark] object RandomForest extends Logging { instr: Option[Instrumentation[_]], parentUID: Option[String] = None): Array[DecisionTreeModel] = { - val multiTimer = new MultiStopwatch(input.sparkContext) + val timers = new MultiStopwatch(input.sparkContext) - multiTimer.addLocal("total") - multiTimer("total").start() + timers.addLocal("total") + timers("total").start() - multiTimer.addLocal("init") - multiTimer("init").start() + timers.addLocal("init") + timers("init").start() val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = @@ -116,10 +116,10 @@ private[spark] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - multiTimer.addLocal("findSplitsBins") - multiTimer("findSplitsBins").start() + timers.addLocal("findSplitsBins") + timers("findSplitsBins").start() val splits = findSplits(retaggedInput, metadata, seed) - multiTimer("findSplitsBins").stop() + timers("findSplitsBins").stop() logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => @@ -146,7 +146,7 @@ private[spark] object RandomForest extends Logging { val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - multiTimer("init").stop() + timers("init").stop() /* * The main idea here is to perform group-wise training of the decision tree nodes thus * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). @@ -175,9 +175,9 @@ private[spark] object RandomForest extends Logging { // Allocate and queue root nodes. val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) - multiTimer.addLocal("findBestSplits") - multiTimer.addLocal("chooseSplits") - multiTimer.addDistributed("binsToBestSplit") + timers.addLocal("findBestSplits") + timers.addLocal("chooseSplits") + timers.addDistributed("binsToBestSplit") while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). @@ -189,18 +189,18 @@ private[spark] object RandomForest extends Logging { s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Choose node splits, and enqueue new nodes as needed. - multiTimer("findBestSplits").start() + timers("findBestSplits").start() RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, multiTimer, nodeIdCache) - multiTimer("findBestSplits").stop() + treeToNodeToIndexInfo, splits, nodeQueue, Option(timers), nodeIdCache) + timers("findBestSplits").stop() } baggedInput.unpersist() - multiTimer("total").stop() + timers("total").stop() logInfo("Internal timing for DecisionTree:") - logInfo(s"$multiTimer") + logInfo(s"$timers") // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { @@ -362,7 +362,7 @@ private[spark] object RandomForest extends Logging { treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], nodeQueue: mutable.Queue[(Int, LearningNode)], - multiTimer: MultiStopwatch, + multiStopwatch: Option[MultiStopwatch] = None, nodeIdCache: Option[NodeIdCache] = None): Unit = { /* @@ -485,6 +485,13 @@ private[spark] object RandomForest extends Logging { } } + val timers = multiStopwatch match { + case Some(timers) => timers + case None => new MultiStopwatch(input.sparkContext) + .addLocal("chooseSplits") + .addLocal("binsToBestSplit") + } + // array of nodes to train indexed by node index in group val nodes = new Array[LearningNode](numNodes) nodesForGroup.foreach { case (treeIndex, nodesForTree) => @@ -494,7 +501,7 @@ private[spark] object RandomForest extends Logging { } // Calculate best splits for all nodes in the group - multiTimer("chooseSplits").start() + timers("chooseSplits").start() // In each partition, iterate all instances and compute aggregate stats for each node, // yield a (nodeIndex, nodeAggregateStats) pair for each node. @@ -550,14 +557,14 @@ private[spark] object RandomForest extends Logging { } // find best split for each node - multiTimer("binsToBestSplit").start() + timers("binsToBestSplit").start() val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - multiTimer("binsToBestSplit").stop() + timers("binsToBestSplit").stop() (nodeIndex, (split, stats)) }.collectAsMap() - multiTimer("chooseSplits").stop() + timers("chooseSplits").stop() val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]](