Skip to content

Commit d78fbcc

Browse files
BenFradetjkbradley
authored andcommitted
[SPARK-14570][ML] Log instrumentation in Random forests
## What changes were proposed in this pull request? Added Instrumentation logging to DecisionTree{Classifier,Regressor} and RandomForest{Classifier,Regressor} ## How was this patch tested? No tests involved since it's logging related. Author: BenFradet <[email protected]> Closes #12536 from BenFradet/SPARK-14570.
1 parent af32f4a commit d78fbcc

File tree

8 files changed

+81
-33
lines changed

8 files changed

+81
-33
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,30 @@ class DecisionTreeClassifier @Since("1.4.0") (
8888
val numClasses: Int = getNumClasses(dataset)
8989
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
9090
val strategy = getOldStrategy(categoricalFeatures, numClasses)
91+
92+
val instr = Instrumentation.create(this, oldDataset)
93+
instr.logParams(params: _*)
94+
9195
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
92-
seed = $(seed), parentUID = Some(uid))
93-
trees.head.asInstanceOf[DecisionTreeClassificationModel]
96+
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
97+
98+
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
99+
instr.logSuccess(m)
100+
m
94101
}
95102

96103
/** (private[ml]) Train a decision tree on an RDD */
97104
private[ml] def train(data: RDD[LabeledPoint],
98105
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
106+
val instr = Instrumentation.create(this, data)
107+
instr.logParams(params: _*)
108+
99109
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
100-
seed = 0L, parentUID = Some(uid))
101-
trees.head.asInstanceOf[DecisionTreeClassificationModel]
110+
seed = 0L, instr = Some(instr), parentUID = Some(uid))
111+
112+
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
113+
instr.logSuccess(m)
114+
m
102115
}
103116

104117
/** (private[ml]) Create a Strategy instance to use with the old API. */

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,18 @@ class RandomForestClassifier @Since("1.4.0") (
105105
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
106106
val strategy =
107107
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
108-
val trees =
109-
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
110-
.map(_.asInstanceOf[DecisionTreeClassificationModel])
108+
109+
val instr = Instrumentation.create(this, oldDataset)
110+
instr.logParams(params: _*)
111+
112+
val trees = RandomForest
113+
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
114+
.map(_.asInstanceOf[DecisionTreeClassificationModel])
115+
111116
val numFeatures = oldDataset.first().features.size
112-
new RandomForestClassificationModel(trees, numFeatures, numClasses)
117+
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
118+
instr.logSuccess(m)
119+
m
113120
}
114121

115122
@Since("1.4.1")

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,30 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
8888
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
8989
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
9090
val strategy = getOldStrategy(categoricalFeatures)
91+
92+
val instr = Instrumentation.create(this, oldDataset)
93+
instr.logParams(params: _*)
94+
9195
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
92-
seed = $(seed), parentUID = Some(uid))
93-
trees.head.asInstanceOf[DecisionTreeRegressionModel]
96+
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
97+
98+
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
99+
instr.logSuccess(m)
100+
m
94101
}
95102

96103
/** (private[ml]) Train a decision tree on an RDD */
97104
private[ml] def train(data: RDD[LabeledPoint],
98105
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
106+
val instr = Instrumentation.create(this, data)
107+
instr.logParams(params: _*)
108+
99109
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
100-
seed = $(seed), parentUID = Some(uid))
101-
trees.head.asInstanceOf[DecisionTreeRegressionModel]
110+
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
111+
112+
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
113+
instr.logSuccess(m)
114+
m
102115
}
103116

104117
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -167,7 +180,7 @@ class DecisionTreeRegressionModel private[ml] (
167180
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
168181
val predictUDF = udf { (features: Vector) => predict(features) }
169182
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
170-
var output = dataset.toDF
183+
var output = dataset.toDF()
171184
if ($(predictionCol).nonEmpty) {
172185
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
173186
}

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,18 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
9999
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
100100
val strategy =
101101
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
102-
val trees =
103-
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
104-
.map(_.asInstanceOf[DecisionTreeRegressionModel])
102+
103+
val instr = Instrumentation.create(this, oldDataset)
104+
instr.logParams(params: _*)
105+
106+
val trees = RandomForest
107+
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
108+
.map(_.asInstanceOf[DecisionTreeRegressionModel])
109+
105110
val numFeatures = oldDataset.first().features.size
106-
new RandomForestRegressionModel(trees, numFeatures)
111+
val m = new RandomForestRegressionModel(trees, numFeatures)
112+
instr.logSuccess(m)
113+
m
107114
}
108115

109116
@Since("1.4.0")

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
2626
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
2727
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
2828
import org.apache.spark.ml.tree._
29+
import org.apache.spark.ml.util.Instrumentation
2930
import org.apache.spark.mllib.regression.LabeledPoint
3031
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
3132
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
@@ -80,6 +81,7 @@ private[spark] object RandomForest extends Logging {
8081

8182
/**
8283
* Train a random forest.
84+
*
8385
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
8486
* @return an unweighted set of trees
8587
*/
@@ -89,6 +91,7 @@ private[spark] object RandomForest extends Logging {
8991
numTrees: Int,
9092
featureSubsetStrategy: String,
9193
seed: Long,
94+
instr: Option[Instrumentation[_]],
9295
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
9396

9497
val timer = new TimeTracker()
@@ -100,13 +103,14 @@ private[spark] object RandomForest extends Logging {
100103
val retaggedInput = input.retag(classOf[LabeledPoint])
101104
val metadata =
102105
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
103-
logDebug("algo = " + strategy.algo)
104-
logDebug("numTrees = " + numTrees)
105-
logDebug("seed = " + seed)
106-
logDebug("maxBins = " + metadata.maxBins)
107-
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
108-
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
109-
logDebug("subsamplingRate = " + strategy.subsamplingRate)
106+
instr match {
107+
case Some(instrumentation) =>
108+
instrumentation.logNumFeatures(metadata.numFeatures)
109+
instrumentation.logNumClasses(metadata.numClasses)
110+
case None =>
111+
logInfo("numFeatures: " + metadata.numFeatures)
112+
logInfo("numClasses: " + metadata.numClasses)
113+
}
110114

111115
// Find the splits and the corresponding bins (interval between the splits) using a sample
112116
// of the input data.
@@ -610,7 +614,9 @@ private[spark] object RandomForest extends Logging {
610614
}
611615

612616
/**
613-
* Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates.
617+
* Calculate the impurity statistics for a given (feature, split) based upon left/right
618+
* aggregates.
619+
*
614620
* @param stats the recycle impurity statistics for this feature's all splits,
615621
* only 'impurity' and 'impurityCalculator' are valid between each iteration
616622
* @param leftImpurityCalculator left node aggregates for this (feature, split)
@@ -668,6 +674,7 @@ private[spark] object RandomForest extends Logging {
668674

669675
/**
670676
* Find the best split for a node.
677+
*
671678
* @param binAggregates Bin statistics.
672679
* @return tuple for best split: (Split, information gain, prediction at node)
673680
*/
@@ -940,6 +947,7 @@ private[spark] object RandomForest extends Logging {
940947
* NOTE: Returned number of splits is set based on `featureSamples` and
941948
* could be different from the specified `numSplits`.
942949
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
950+
*
943951
* @param featureSamples feature values of each sample
944952
* @param metadata decision tree metadata
945953
* NOTE: `metadata.numbins` will be changed accordingly
@@ -1083,6 +1091,7 @@ private[spark] object RandomForest extends Logging {
10831091

10841092
/**
10851093
* Get the number of values to be stored for this node in the bin aggregates.
1094+
*
10861095
* @param featureSubset Indices of features which may be split at this node.
10871096
* If None, then use all features.
10881097
*/

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ class DecisionTree private[spark] (private val strategy: Strategy, private val s
6262
*/
6363
@Since("1.2.0")
6464
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
65-
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
66-
seed = seed)
65+
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
6766
val rfModel = rf.run(input)
6867
rfModel.trees(0)
6968
}
@@ -88,7 +87,7 @@ object DecisionTree extends Serializable with Logging {
8887
* categorical), depth of the tree, quantile calculation strategy, etc.
8988
* @return DecisionTreeModel that can be used for prediction.
9089
*/
91-
@Since("1.0.0")
90+
@Since("1.0.0")
9291
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
9392
new DecisionTree(strategy).run(input)
9493
}

mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ import org.apache.spark.util.Utils
4545
* - sqrt: recommended by Breiman manual for random forests
4646
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
4747
* package.
48+
*
4849
* @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
4950
* @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
5051
* random forests]]
51-
*
5252
* @param strategy The configuration parameters for the random forest algorithm which specify
5353
* the type of random forest (classification or regression), feature type
5454
* (continuous, categorical), depth of the tree, quantile calculation strategy,
@@ -91,7 +91,7 @@ private class RandomForest (
9191
*/
9292
def run(input: RDD[LabeledPoint]): RandomForestModel = {
9393
val trees: Array[NewDTModel] =
94-
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
94+
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong, None)
9595
new RandomForestModel(strategy.algo, trees.map(_.toOld))
9696
}
9797

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
322322
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
323323

324324
val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
325-
seed = 42).head
325+
seed = 42, instr = None).head
326326
model.rootNode match {
327327
case n: InternalNode => n.split match {
328328
case s: CategoricalSplit =>
@@ -345,9 +345,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
345345
new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)
346346

347347
val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
348-
seed = 42).head
348+
seed = 42, instr = None).head
349349
val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
350-
seed = 42).head
350+
seed = 42, instr = None).head
351351

352352
def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
353353
case n: InternalNode =>

0 commit comments

Comments
 (0)