diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 7bef899a633d9..8e5b80d61d323 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.MultiStopwatch
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
@@ -246,9 +247,11 @@ private[spark] object GradientBoostedTrees extends Logging {
       boostingStrategy: OldBoostingStrategy,
       validate: Boolean,
       seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
-    val timer = new TimeTracker()
-    timer.start("total")
-    timer.start("init")
+    val multiTimer = new MultiStopwatch(input.sparkContext)
+    multiTimer.addLocal("total")
+    multiTimer.addLocal("init")
+    multiTimer("total").start()
+    multiTimer("init").start()
 
     boostingStrategy.assertValid()
 
@@ -279,14 +282,15 @@ private[spark] object GradientBoostedTrees extends Logging {
     val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
       treeStrategy.getCheckpointInterval, input.sparkContext)
 
-    timer.stop("init")
+    multiTimer("init").stop()
 
     logDebug("##########")
     logDebug("Building tree 0")
     logDebug("##########")
 
     // Initialize tree
-    timer.start("building tree 0")
+    multiTimer.addLocal("building tree 0")
+    multiTimer("building tree 0").start()
     val firstTree = new DecisionTreeRegressor().setSeed(seed)
     val firstTreeModel = firstTree.train(input, treeStrategy)
     val firstTreeWeight = 1.0
@@ -299,7 +303,7 @@ private[spark] object GradientBoostedTrees extends Logging {
     logDebug("error of gbt = " + predError.values.mean())
 
     // Note: A model of type regression is used since we require raw prediction
-    timer.stop("building tree 0")
+    multiTimer("building tree 0").stop()
 
     var validatePredError: RDD[(Double, Double)] =
       computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
@@ -315,13 +319,14 @@ private[spark] object GradientBoostedTrees extends Logging {
         LabeledPoint(-loss.gradient(pred, point.label), point.features)
       }
 
-      timer.start(s"building tree $m")
+      multiTimer.addLocal(s"building tree $m")
+      multiTimer(s"building tree $m").start()
       logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
       val dt = new DecisionTreeRegressor().setSeed(seed + m)
       val model = dt.train(data, treeStrategy)
-      timer.stop(s"building tree $m")
+      multiTimer(s"building tree $m").stop()
       // Update partial model
       baseLearners(m) = model
       // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
@@ -355,10 +360,10 @@ private[spark] object GradientBoostedTrees extends Logging {
       m += 1
     }
 
-    timer.stop("total")
+    multiTimer("total").stop()
 
     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    logInfo(s"$multiTimer")
 
     predErrorCheckpointer.deleteAllCheckpoints()
     validatePredErrorCheckpointer.deleteAllCheckpoints()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eba..03892e129cf5e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.ml.util.{Instrumentation, LocalStopwatch, MultiStopwatch, Stopwatch}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
@@ -94,11 +94,13 @@ private[spark] object RandomForest extends Logging {
       instr: Option[Instrumentation[_]],
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 
-    val timer = new TimeTracker()
+    val timers = new MultiStopwatch(input.sparkContext)
 
-    timer.start("total")
+    timers.addLocal("total")
+    timers("total").start()
 
-    timer.start("init")
+    timers.addLocal("init")
+    timers("init").start()
 
     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata =
@@ -114,9 +116,11 @@ private[spark] object RandomForest extends Logging {
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    timer.start("findSplits")
+    timers.addLocal("findSplitsBins")
+    timers("findSplitsBins").start()
     val splits = findSplits(retaggedInput, metadata, seed)
-    timer.stop("findSplits")
+    timers("findSplitsBins").stop()
+
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
       s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
@@ -142,6 +146,7 @@ private[spark] object RandomForest extends Logging {
     val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+    timers("init").stop()
 
     /*
      * The main idea here is to perform group-wise training of the decision tree nodes thus
      * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
@@ -170,8 +175,9 @@ private[spark] object RandomForest extends Logging {
     // Allocate and queue root nodes.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
     Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
-
-    timer.stop("init")
+    timers.addLocal("findBestSplits")
+    timers.addLocal("chooseSplits")
+    timers.addDistributed("binsToBestSplit")
 
     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
@@ -183,18 +189,18 @@ private[spark] object RandomForest extends Logging {
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
 
       // Choose node splits, and enqueue new nodes as needed.
-      timer.start("findBestSplits")
+      timers("findBestSplits").start()
       RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
-      timer.stop("findBestSplits")
+        treeToNodeToIndexInfo, splits, nodeQueue, Option(timers), nodeIdCache)
+      timers("findBestSplits").stop()
     }
 
     baggedInput.unpersist()
 
-    timer.stop("total")
+    timers("total").stop()
 
     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    logInfo(s"$timers")
 
     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
@@ -356,7 +362,7 @@ private[spark] object RandomForest extends Logging {
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       nodeQueue: mutable.Queue[(Int, LearningNode)],
-      timer: TimeTracker = new TimeTracker,
+      multiStopwatch: Option[MultiStopwatch] = None,
       nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
     /*
@@ -479,6 +485,13 @@ private[spark] object RandomForest extends Logging {
       }
     }
 
+    val timers = multiStopwatch match {
+      case Some(timers) => timers
+      case None => new MultiStopwatch(input.sparkContext)
+        .addLocal("chooseSplits")
+        .addLocal("binsToBestSplit")
+    }
+
     // array of nodes to train indexed by node index in group
     val nodes = new Array[LearningNode](numNodes)
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
@@ -488,7 +501,7 @@ private[spark] object RandomForest extends Logging {
     }
 
     // Calculate best splits for all nodes in the group
-    timer.start("chooseSplits")
+    timers("chooseSplits").start()
 
     // In each partition, iterate all instances and compute aggregate stats for each node,
     // yield a (nodeIndex, nodeAggregateStats) pair for each node.
@@ -544,12 +557,14 @@ private[spark] object RandomForest extends Logging {
         }
 
         // find best split for each node
+        timers("binsToBestSplit").start()
         val (split: Split, stats: ImpurityStats) =
           binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+        timers("binsToBestSplit").stop()
         (nodeIndex, (split, stats))
       }.collectAsMap()
 
-    timer.stop("chooseSplits")
+    timers("chooseSplits").stop()
 
     val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
       Array.fill[mutable.Map[Int, NodeIndexUpdater]](
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
deleted file mode 100644
index 4cc250aa462e3..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.tree.impl
-
-import scala.collection.mutable.{HashMap => MutableHashMap}
-
-/**
- * Time tracker implementation which holds labeled timers.
- */
-private[spark] class TimeTracker extends Serializable {
-
-  private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
-
-  private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
-
-  /**
-   * Starts a new timer, or re-starts a stopped timer.
-   */
-  def start(timerLabel: String): Unit = {
-    val currentTime = System.nanoTime()
-    if (starts.contains(timerLabel)) {
-      throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
-        s" timerLabel = $timerLabel before that timer was stopped.")
-    }
-    starts(timerLabel) = currentTime
-  }
-
-  /**
-   * Stops a timer and returns the elapsed time in seconds.
-   */
-  def stop(timerLabel: String): Double = {
-    val currentTime = System.nanoTime()
-    if (!starts.contains(timerLabel)) {
-      throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
-        s" timerLabel = $timerLabel, but that timer was not started.")
-    }
-    val elapsed = currentTime - starts(timerLabel)
-    starts.remove(timerLabel)
-    if (totals.contains(timerLabel)) {
-      totals(timerLabel) += elapsed
-    } else {
-      totals(timerLabel) = elapsed
-    }
-    elapsed / 1e9
-  }
-
-  /**
-   * Print all timing results in seconds.
-   */
-  override def toString: String = {
-    totals.map { case (label, elapsed) =>
-      s"  $label: ${elapsed / 1e9}"
-    }.mkString("\n")
-  }
-}
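For reference, below is a minimal driver-side sketch of the MultiStopwatch pattern this patch switches to, using only the calls that appear in the diff above: a constructor taking a SparkContext, chainable addLocal plus addDistributed registration, apply(name) to fetch a timer with start()/stop(), and string interpolation to log the timing summary. The SparkContext setup and the object name are illustrative only, and because MultiStopwatch is a Spark-internal utility the sketch assumes it is compiled inside the Spark source tree (here, the same package as the patched files).

package org.apache.spark.ml.tree.impl

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.util.MultiStopwatch

object MultiStopwatchSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative local context; the patched code reuses input.sparkContext instead.
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("stopwatch-sketch"))

    // Register timers up front, as the patched RandomForest.run does: local timers
    // accumulate driver-side time, distributed timers aggregate time measured in tasks.
    val timers = new MultiStopwatch(sc)
      .addLocal("total")
      .addLocal("chooseSplits")
    timers.addDistributed("binsToBestSplit")

    timers("total").start()
    timers("chooseSplits").start()
    Thread.sleep(50) // stand-in for driver-side work such as choosing splits
    timers("chooseSplits").stop()
    timers("total").stop()

    // Same reporting pattern as the patch: interpolation calls MultiStopwatch.toString.
    println(s"$timers")
    sc.stop()
  }
}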