GradientBoostedTrees.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.MultiStopwatch
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
@@ -246,9 +247,11 @@ private[spark] object GradientBoostedTrees extends Logging {
       boostingStrategy: OldBoostingStrategy,
       validate: Boolean,
       seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
-    val timer = new TimeTracker()
-    timer.start("total")
-    timer.start("init")
+    val multiTimer = new MultiStopwatch(input.sparkContext)
+    multiTimer.addLocal("total")
+    multiTimer.addLocal("init")
+    multiTimer("total").start()
+    multiTimer("init").start()
 
     boostingStrategy.assertValid()
 
@@ -279,14 +282,15 @@
     val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
       treeStrategy.getCheckpointInterval, input.sparkContext)
 
-    timer.stop("init")
+    multiTimer("init").stop()
 
     logDebug("##########")
     logDebug("Building tree 0")
     logDebug("##########")
 
     // Initialize tree
-    timer.start("building tree 0")
+    multiTimer.addLocal("building tree 0")
+    multiTimer("building tree 0").start()
     val firstTree = new DecisionTreeRegressor().setSeed(seed)
     val firstTreeModel = firstTree.train(input, treeStrategy)
     val firstTreeWeight = 1.0
@@ -299,7 +303,7 @@
     logDebug("error of gbt = " + predError.values.mean())
 
     // Note: A model of type regression is used since we require raw prediction
-    timer.stop("building tree 0")
+    multiTimer("building tree 0").stop()
 
     var validatePredError: RDD[(Double, Double)] =
       computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
@@ -315,13 +319,14 @@
         LabeledPoint(-loss.gradient(pred, point.label), point.features)
       }
 
-      timer.start(s"building tree $m")
+      multiTimer.addLocal(s"building tree $m")
+      multiTimer(s"building tree $m").start()
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
       val dt = new DecisionTreeRegressor().setSeed(seed + m)
       val model = dt.train(data, treeStrategy)
-      timer.stop(s"building tree $m")
+      multiTimer(s"building tree $m").stop()
       // Update partial model
       baseLearners(m) = model
       // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
@@ -355,10 +360,10 @@
       m += 1
     }
 
-    timer.stop("total")
+    multiTimer("total").stop()
 
     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    logInfo(s"$multiTimer")
 
     predErrorCheckpointer.deleteAllCheckpoints()
     validatePredErrorCheckpointer.deleteAllCheckpoints()
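
Note on the pattern above: the diff swaps the old TimeTracker for a MultiStopwatch, registering each named timer with addLocal (or addDistributed) and wrapping every phase in start/stop calls. Below is a minimal, self-contained sketch of that usage. Only MultiStopwatch, addLocal, addDistributed, apply, start and stop are taken from this diff; the object and method names, the SparkContext wiring and the simulated work are illustrative assumptions, and these utilities are Spark-internal (private[spark]), so the pattern only applies inside Spark's own code paths.

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.MultiStopwatch

// Illustrative sketch only; not part of the change above.
object StopwatchSketch {
  def run(sc: SparkContext): Unit = {
    val timers = new MultiStopwatch(sc)
    timers.addLocal("total")            // timed on the driver
    timers.addLocal("init")             // timed on the driver
    timers.addDistributed("taskWork")   // elapsed time accumulated across executors

    timers("total").start()

    timers("init").start()
    // ... driver-side setup would go here ...
    timers("init").stop()

    // Capture only the Stopwatch itself so the closure stays small.
    val taskTimer = timers("taskWork")
    sc.parallelize(1 to 4).foreach { _ =>
      taskTimer.start()
      // ... per-record work ...
      taskTimer.stop()
    }

    timers("total").stop()
    // Interpolating the MultiStopwatch prints every named timer with its elapsed
    // time, which is what the logInfo(s"$multiTimer") call above relies on.
    println(s"$timers")
  }
}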
RandomForest.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.ml.util.{Instrumentation, LocalStopwatch, MultiStopwatch, Stopwatch}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
@@ -94,11 +94,13 @@ private[spark] object RandomForest extends Logging {
       instr: Option[Instrumentation[_]],
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 
-    val timer = new TimeTracker()
+    val timers = new MultiStopwatch(input.sparkContext)
 
-    timer.start("total")
+    timers.addLocal("total")
+    timers("total").start()
 
-    timer.start("init")
+    timers.addLocal("init")
+    timers("init").start()
 
     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata =
@@ -114,9 +116,11 @@
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    timer.start("findSplits")
+    timers.addLocal("findSplitsBins")
+    timers("findSplitsBins").start()
     val splits = findSplits(retaggedInput, metadata, seed)
-    timer.stop("findSplits")
+    timers("findSplitsBins").stop()
+
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
       s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
@@ -142,6 +146,7 @@
     val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
 
+    timers("init").stop()
     /*
      * The main idea here is to perform group-wise training of the decision tree nodes thus
      * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
@@ -170,8 +175,9 @@
     // Allocate and queue root nodes.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
     Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
-
-    timer.stop("init")
+    timers.addLocal("findBestSplits")
+    timers.addLocal("chooseSplits")
+    timers.addDistributed("binsToBestSplit")
 
     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
@@ -183,18 +189,18 @@
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
 
       // Choose node splits, and enqueue new nodes as needed.
-      timer.start("findBestSplits")
+      timers("findBestSplits").start()
       RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
-      timer.stop("findBestSplits")
+        treeToNodeToIndexInfo, splits, nodeQueue, Option(timers), nodeIdCache)
+      timers("findBestSplits").stop()
     }
 
     baggedInput.unpersist()
 
-    timer.stop("total")
+    timers("total").stop()
 
     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    logInfo(s"$timers")
 
     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
@@ -356,7 +362,7 @@
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       nodeQueue: mutable.Queue[(Int, LearningNode)],
-      timer: TimeTracker = new TimeTracker,
+      multiStopwatch: Option[MultiStopwatch] = None,
      nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
     /*
@@ -479,6 +485,13 @@
       }
     }
 
+    val timers = multiStopwatch match {
+      case Some(timers) => timers
+      case None => new MultiStopwatch(input.sparkContext)
+        .addLocal("chooseSplits")
+        .addLocal("binsToBestSplit")
+    }
+
     // array of nodes to train indexed by node index in group
     val nodes = new Array[LearningNode](numNodes)
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
@@ -488,7 +501,7 @@
     }
 
     // Calculate best splits for all nodes in the group
-    timer.start("chooseSplits")
+    timers("chooseSplits").start()
 
     // In each partition, iterate all instances and compute aggregate stats for each node,
     // yield a (nodeIndex, nodeAggregateStats) pair for each node.
@@ -544,12 +557,14 @@
         }
 
         // find best split for each node
+        timers("binsToBestSplit").start()
         val (split: Split, stats: ImpurityStats) =
           binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+        timers("binsToBestSplit").stop()
         (nodeIndex, (split, stats))
     }.collectAsMap()
 
-    timer.stop("chooseSplits")
+    timers("chooseSplits").stop()
 
     val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
       Array.fill[mutable.Map[Int, NodeIndexUpdater]](
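
One design note on the findBestSplits hunk above: the timer argument becomes an Option[MultiStopwatch], and when the caller passes None the method builds a throwaway instance by chaining addLocal calls, which works because addLocal returns the stopwatch collection it was called on. A small sketch of that fallback in isolation; the TimerFallbackSketch and resolveTimers names are hypothetical, while MultiStopwatch and addLocal come from the diff.

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.MultiStopwatch

// Hypothetical helper mirroring the fallback in findBestSplits: callers that pass
// None still get working (but unreported) stopwatches to start and stop.
object TimerFallbackSketch {
  def resolveTimers(sc: SparkContext, provided: Option[MultiStopwatch]): MultiStopwatch =
    provided.getOrElse {
      new MultiStopwatch(sc)
        .addLocal("chooseSplits")
        .addLocal("binsToBestSplit")
    }
}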

(The remaining file in this change was deleted; its name and contents were not captured in this view.)