--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -272,6 +272,8 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
+    params.checkpointDir.foreach(sc.setCheckpointDir)
+
     val strategy
       = new Strategy(
         algo = params.algo,
@@ -282,7 +284,6 @@ object DecisionTreeRunner {
         minInstancesPerNode = params.minInstancesPerNode,
         minInfoGain = params.minInfoGain,
         useNodeIdCache = params.useNodeIdCache,
-        checkpointDir = params.checkpointDir,
         checkpointInterval = params.checkpointInterval)
     if (params.numTrees == 1) {
       val startTime = System.nanoTime()
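The example now applies its optional checkpoint directory directly to the SparkContext before building the Strategy, instead of threading it through the Strategy constructor. A minimal sketch of the resulting calling pattern, assuming the Spark 1.x MLlib API shown in this diff (the app name, directory path, and parameter values are illustrative, not from this PR):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

val sc = new SparkContext(new SparkConf().setAppName("DecisionTreeExample"))

// Checkpointing is configured once, on the context; Strategy no longer
// carries a checkpointDir field.
val checkpointDir: Option[String] = Some("/tmp/dt-checkpoints") // illustrative
checkpointDir.foreach(sc.setCheckpointDir)

val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = Gini,
  maxDepth = 5,
  numClasses = 2,
  useNodeIdCache = true,   // the node Id cache checkpoints through sc's directory
  checkpointInterval = 10)
```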
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -202,7 +202,6 @@ private class RandomForest (
       Some(NodeIdCache.init(
         data = baggedInput,
         numTrees = numTrees,
-        checkpointDir = strategy.checkpointDir,
         checkpointInterval = strategy.checkpointInterval,
         initVal = 1))
     } else {
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param subsamplingRate Fraction of the training data used for learning decision tree.
  * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
  *                       maintain a separate RDD of node Id cache for each row.
- * @param checkpointDir If the node Id cache is used, it will help to checkpoint
- *                      the node Id cache periodically. This is the checkpoint directory
- *                      to be used for the node Id cache.
  * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
- *                           E.g. 10 means that the cache will get checkpointed every 10 updates.
+ *                           E.g. 10 means that the cache will get checkpointed every 10 updates. If
+ *                           the checkpoint directory is not set in
+ *                           [[org.apache.spark.SparkContext]], this setting is ignored.
  */
 @Experimental
 class Strategy (
@@ -82,7 +81,6 @@ class Strategy (
     @BeanProperty var maxMemoryInMB: Int = 256,
     @BeanProperty var subsamplingRate: Double = 1,
     @BeanProperty var useNodeIdCache: Boolean = false,
-    @BeanProperty var checkpointDir: Option[String] = None,
     @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
   def isMulticlassClassification =
@@ -165,7 +163,7 @@ class Strategy (
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
   }
 }
 
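The updated scaladoc pins down the new contract: checkpointInterval has no effect unless the application has registered a checkpoint directory on the SparkContext. Since the remaining fields are @BeanProperty vars, a Java-friendly sketch of configuring them looks like this (values are illustrative; the setter names follow from @BeanProperty):

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5)
strategy.setUseNodeIdCache(true)
// Checkpoint the node Id cache every 5 updates, but only if
// sc.setCheckpointDir(...) was called first; otherwise this is a no-op.
strategy.setCheckpointInterval(5)
// Note: Strategy no longer has a setCheckpointDir; the directory is owned
// by the SparkContext.
```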
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
  * The nodeIdsForInstances RDD needs to be updated at each iteration.
  * @param nodeIdsForInstances The initial values in the cache
  *                            (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointDir The checkpoint directory where
- *                      the checkpointed files will be stored.
  * @param checkpointInterval The checkpointing interval
  *                           (how often should the cache be checkpointed.).
  */
 @DeveloperApi
 private[tree] class NodeIdCache(
   var nodeIdsForInstances: RDD[Array[Int]],
-  val checkpointDir: Option[String],
   val checkpointInterval: Int) {
 
   // Keep a reference to a previous node Ids for instances.
@@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
   private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
   private var rddUpdateCount = 0
 
-  // If a checkpoint directory is given, and there's no prior checkpoint directory,
-  // then set the checkpoint directory with the given one.
-  if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
-    nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
-  }
-
   /**
    * Update the node index values in the cache.
    * This updates the RDD and its lineage.
@@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
    * Initialize the node Id cache with initial node Id values.
    * @param data The RDD of training rows.
    * @param numTrees The number of trees that we want to create cache for.
-   * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
    * @param checkpointInterval The checkpointing interval
    *                           (how often should the cache be checkpointed.).
    * @param initVal The initial values in the cache.
@@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
   def init(
       data: RDD[BaggedPoint[TreePoint]],
       numTrees: Int,
-      checkpointDir: Option[String],
       checkpointInterval: Int,
       initVal: Int = 1): NodeIdCache = {
     new NodeIdCache(
       data.map(_ => Array.fill[Int](numTrees)(initVal)),
-      checkpointDir,
       checkpointInterval)
   }
 }
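What NodeIdCache keeps doing after this removal is the same queue-based periodic checkpointing; only the directory plumbing moved out to the application. A simplified, illustrative sketch of that pattern under the new contract (the class and method names below are hypothetical, not the PR's code):

```scala
import scala.collection.mutable
import org.apache.spark.rdd.RDD

// Hypothetical helper: checkpoint a repeatedly updated RDD every
// `checkpointInterval` updates, mirroring NodeIdCache's queue bookkeeping.
class PeriodicCheckpointer[T](checkpointInterval: Int) {
  private val checkpointQueue = mutable.Queue[RDD[T]]()
  private var updateCount = 0

  def update(rdd: RDD[T]): Unit = {
    updateCount += 1
    // Under the new contract, checkpoint only if the application registered
    // a directory via sc.setCheckpointDir; otherwise skip silently.
    if (rdd.sparkContext.getCheckpointDir.nonEmpty &&
        updateCount % checkpointInterval == 0) {
      rdd.checkpoint()
      checkpointQueue.enqueue(rdd)
      // Drop references to older entries once their checkpoint completed
      // (the real NodeIdCache also deletes the old checkpoint files).
      while (checkpointQueue.size > 1 && checkpointQueue.front.isCheckpointed) {
        checkpointQueue.dequeue()
      }
    }
    // Run an action so any pending checkpoint materializes; in MLlib the
    // training loop's own actions normally take care of this.
    rdd.count()
  }
}
```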