
Commit 3cb1b57

MrBago authored and jkbradley committed
[SPARK-24852][ML] Update spark.ml to use Instrumentation.instrumented.
## What changes were proposed in this pull request?

Followup for #21719. Update spark.ml training code to fully wrap instrumented methods and remove old instrumentation APIs.

## How was this patch tested?

Existing tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Bago Amirbekian <[email protected]>

Closes #21799 from MrBago/new-instrumentation-apis2.
1 parent 244bcff commit 3cb1b57

24 files changed: +153 −186 lines changed
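
Before the per-file diffs, a minimal sketch of the wrapper pattern this commit migrates to. This is not the Spark source: the `Instrumentation` class below is a simplified stand-in for `org.apache.spark.ml.util.Instrumentation`, which additionally records the stage UID, dataset, params and named values through the Spark logger.

```scala
import scala.util.control.NonFatal

// Simplified stand-in for org.apache.spark.ml.util.Instrumentation (illustration only).
class Instrumentation {
  def logPipelineStage(stage: AnyRef): Unit = println(s"stage: ${stage.getClass.getSimpleName}")
  def logDataset(data: AnyRef): Unit = println(s"dataset: ${data.getClass.getSimpleName}")
  def logParams(stage: AnyRef, params: String*): Unit = println(s"params: ${params.mkString(", ")}")
  def logSuccess(): Unit = println("training finished successfully")
  def logFailure(e: Throwable): Unit = println(s"training failed: ${e.getMessage}")
}

object Instrumentation {
  // `instrumented` owns the instrumentation lifecycle: it creates the instance,
  // runs the training body, and logs success or failure itself. That is why the
  // per-algorithm `instr.logSuccess(model)` calls disappear in the diffs below.
  def instrumented[T](body: Instrumentation => T): T = {
    val instr = new Instrumentation
    try {
      val result = body(instr)
      instr.logSuccess()
      result
    } catch {
      case NonFatal(e) =>
        instr.logFailure(e)
        throw e
    }
  }
}
```

A training method then becomes `def train(...) = instrumented { instr => ... }`: the body logs the stage, dataset and params explicitly (hence the new `logParams(this, ...)` form), and its last expression is the fitted model.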

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 12 additions & 12 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
@@ -96,8 +97,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = set(seed, value)
 
-  override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
-    val instr = Instrumentation.create(this, dataset)
+  override protected def train(
+      dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
@@ -112,30 +115,27 @@ class DecisionTreeClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 
-    instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+    instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
 
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
-    val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
-    instr.logSuccess(m)
-    m
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
   /** (private[ml]) Train a decision tree on an RDD */
   private[ml] def train(data: RDD[LabeledPoint],
-      oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
-    val instr = Instrumentation.create(this, data)
-    instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+      oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
+    instr.logPipelineStage(this)
+    instr.logDataset(data)
+    instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
 
     val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = 0L, instr = Some(instr), parentUID = Some(uid))
 
-    val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
-    instr.logSuccess(m)
-    m
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
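
To make the resulting shape of a migrated train method concrete, here is a toy trainer written against the stand-in wrapper sketched above; `SimpleTrainer` and `SimpleModel` are invented names for illustration, not Spark classes.

```scala
// Toy model and trainer (hypothetical names); assumes the Instrumentation sketch above.
case class SimpleModel(weights: Seq[Double])

class SimpleTrainer {
  // The body only logs the stage, the data and the params; the wrapper logs
  // success or failure, and the block's last expression is the fitted model.
  def train(data: Seq[Double]): SimpleModel = Instrumentation.instrumented { instr =>
    instr.logPipelineStage(this)
    instr.logDataset(data)
    instr.logParams(this, "stepSize=0.5")
    SimpleModel(data.map(_ * 0.5))
  }
}

object SimpleTrainerDemo extends App {
  // Prints the instrumentation messages followed by SimpleModel(List(0.5, 1.0, 1.5)).
  println(new SimpleTrainer().train(Seq(1.0, 2.0, 3.0)))
}
```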

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 7 additions & 7 deletions
@@ -31,9 +31,9 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 
@@ -152,7 +152,8 @@ class GBTClassifier @Since("1.4.0") (
     set(validationIndicatorCol, value)
   }
 
-  override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
+  override protected def train(
+      dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
 
@@ -189,8 +190,9 @@ class GBTClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
       validationIndicatorCol)
@@ -204,9 +206,7 @@ class GBTClassifier @Since("1.4.0") (
       GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
     }
 
-    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
-    instr.logSuccess(m)
-    m
+    new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.1")

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 6 additions & 6 deletions
@@ -33,6 +33,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
@@ -162,16 +163,17 @@ class LinearSVC @Since("2.2.0") (
   @Since("2.2.0")
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
-  override protected def train(dataset: Dataset[_]): LinearSVCModel = {
+  override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
 
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, regParam, maxIter, fitIntercept, tol, standardization, threshold,
       aggregationDepth)
 
     val (summarizer, labelSummarizer) = {
@@ -276,9 +278,7 @@ class LinearSVC @Since("2.2.0") (
       (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result())
     }
 
-    val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
-    instr.logSuccess(model)
-    model
+    copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
   }
 }

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ class LogisticRegression @Since("1.2.0") (
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    instr.logParams(regParam, elasticNetParam, standardization, threshold,
+    instr.logParams(this, regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
     val (summarizer, labelSummarizer) = {

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 7 additions & 7 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.sql.Dataset
 
 /** Params for Multilayer Perceptron. */
@@ -230,9 +231,11 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
   * @param dataset Training dataset
   * @return Fitted model
   */
-  override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = {
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(labelCol, featuresCol, predictionCol, layers, maxIter, tol,
+  override protected def train(
+      dataset: Dataset[_]): MultilayerPerceptronClassificationModel = instrumented { instr =>
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, labelCol, featuresCol, predictionCol, layers, maxIter, tol,
       blockSize, solver, stepSize, seed)
 
     val myLayers = $(layers)
@@ -264,10 +267,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
     }
     trainer.setStackSize($(blockSize))
     val mlpModel = trainer.train(data)
-    val model = new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
-
-    instr.logSuccess(model)
-    model
+    new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
   }
 }

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 6 additions & 6 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
@@ -125,8 +126,9 @@ class NaiveBayes @Since("1.5.0") (
   */
  private[spark] def trainWithLabelCheck(
      dataset: Dataset[_],
-      positiveLabel: Boolean): NaiveBayesModel = {
-    val instr = Instrumentation.create(this, dataset)
+      positiveLabel: Boolean): NaiveBayesModel = instrumented { instr =>
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     if (positiveLabel && isDefined(thresholds)) {
       val numClasses = getNumClasses(dataset)
       instr.logNumClasses(numClasses)
@@ -148,7 +150,7 @@ class NaiveBayes @Since("1.5.0") (
       }
     }
 
-    instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
+    instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
       probabilityCol, modelType, smoothing, thresholds)
 
     val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
@@ -204,9 +206,7 @@ class NaiveBayes @Since("1.5.0") (
 
     val pi = Vectors.dense(piArray)
     val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
-    val model = new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
-    instr.logSuccess(model)
-    model
+    new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
  }
 
  @Since("1.5.0")

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 5 additions & 4 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -362,11 +363,12 @@ final class OneVsRest @Since("1.4.0") (
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): OneVsRestModel = {
+  override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr =>
     transformSchema(dataset.schema)
 
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
     instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
 
     // determine number of classes either from metadata if provided, or via computation.
@@ -440,7 +442,6 @@ final class OneVsRest @Since("1.4.0") (
      case attr: Attribute => attr
    }
    val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
-    instr.logSuccess(model)
    copyValues(model)
  }

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 7 additions & 6 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
@@ -115,8 +116,10 @@ class RandomForestClassifier @Since("1.4.0") (
   override def setFeatureSubsetStrategy(value: String): this.type =
     set(featureSubsetStrategy, value)
 
-  override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
-    val instr = Instrumentation.create(this, dataset)
+  override protected def train(
+      dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr =>
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
@@ -131,7 +134,7 @@ class RandomForestClassifier @Since("1.4.0") (
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
 
-    instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
+    instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
       impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
 
@@ -140,11 +143,9 @@ class RandomForestClassifier @Since("1.4.0") (
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
     val numFeatures = oldDataset.first().features.size
-    val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
-    instr.logSuccess(m)
-    m
+    new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
   }
 
   @Since("1.4.1")

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 6 additions & 6 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
   BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -257,12 +258,13 @@ class BisectingKMeans @Since("2.0.0") (
   def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+  override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
     val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
 
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(featuresCol, predictionCol, k, maxIter, seed,
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed,
       minDivisibleClusterSize, distanceMeasure)
 
     val bkm = new MLlibBisectingKMeans()
@@ -275,10 +277,8 @@ class BisectingKMeans @Since("2.0.0") (
     val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
     val summary = new BisectingKMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter))
-    model.setSummary(Some(summary))
     instr.logNamedValue("clusterSizes", summary.clusterSizes)
-    instr.logSuccess(model)
-    model
+    model.setSummary(Some(summary))
   }
 
   @Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 6 additions & 6 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.stat.distribution.MultivariateGaussian
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
   Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
@@ -335,7 +336,7 @@ class GaussianMixture @Since("2.0.0") (
   private val numSamples = 5
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+  override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
 
     val sc = dataset.sparkSession.sparkContext
@@ -352,8 +353,9 @@ class GaussianMixture @Since("2.0.0") (
       s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
       s" matrix is quadratic in the number of features.")
 
-    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
     instr.logNumFeatures(numFeatures)
 
     val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(
@@ -425,11 +427,9 @@ class GaussianMixture @Since("2.0.0") (
     val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
     val summary = new GaussianMixtureSummary(model.transform(dataset),
       $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter)
-    model.setSummary(Some(summary))
     instr.logNamedValue("logLikelihood", logLikelihood)
     instr.logNamedValue("clusterSizes", summary.clusterSizes)
-    instr.logSuccess(model)
-    model
+    model.setSummary(Some(summary))
   }
 
   @Since("2.0.0")
