
Commit f7d9e3d

huaxingao authored and srowen committed
[SPARK-23631][ML][PYSPARK] Add summary to RandomForestClassificationModel
### What changes were proposed in this pull request?
Add summary to RandomForestClassificationModel...

### Why are the changes needed?
So that users can get a summary of this classification model and retrieve common metrics such as accuracy, weightedTruePositiveRate, roc (for binary), pr curves (for binary), etc.

### Does this PR introduce _any_ user-facing change?
Yes:
```
RandomForestClassificationModel.summary
RandomForestClassificationModel.evaluate
```

### How was this patch tested?
Added new tests.

Closes #28913 from huaxingao/rf_summary.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
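For context, a minimal usage sketch of the new user-facing API in Scala follows (not part of the diff; the SparkSession `spark` and the toy DataFrame are assumptions made for illustration):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors

// Toy binary-classification data, purely illustrative.
val train = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.2))
)).toDF("label", "features")

val model = new RandomForestClassifier().fit(train)

// Training summary attached by this patch; areaUnderROC exists because the model is binary.
val summary = model.summary
println(summary.accuracy)
println(summary.asBinary.areaUnderROC)

// Evaluate on a (here reused) dataset and read common metrics from the returned summary.
val testSummary = model.evaluate(train)
println(testSummary.weightedTruePositiveRate)
```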
1 parent 15fb5d7 commit f7d9e3d

File tree

10 files changed, +496 −52 lines changed


mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ private[classification] trait ClassificationSummary extends Serializable {
   @Since("3.1.0")
   def labelCol: String
 
-  /** Field in "predictions" which gives the weight of each instance as a vector. */
+  /** Field in "predictions" which gives the weight of each instance. */
   @Since("3.1.0")
   def weightCol: String

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 23 additions & 0 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
 import org.apache.spark.rdd.RDD
@@ -269,4 +270,26 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @return predicted label
    */
   protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
+
+  /**
+   * If the rawPrediction and prediction columns are set, this method returns the current model,
+   * otherwise it generates new columns for them and sets them as columns on a new copy of
+   * the current model
+   */
+  private[classification] def findSummaryModel():
+    (ClassificationModel[FeaturesType, M], String, String) = {
+    val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
+      copy(ParamMap.empty)
+        .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
+        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+    } else if ($(rawPredictionCol).isEmpty) {
+      copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
+        java.util.UUID.randomUUID.toString)
+    } else if ($(predictionCol).isEmpty) {
+      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+    } else {
+      this
+    }
+    (model, model.getRawPredictionCol, model.getPredictionCol)
+  }
 }
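
The hunk above hoists `findSummaryModel()` from the individual models up into `ClassificationModel`, so any classifier can locate (or create) the output columns its summary needs. The rule it implements is simple: keep a configured column name if it is set, otherwise work on a copy of the model with a unique temporary name. A stand-alone sketch of that naming rule follows; it is an illustration only, with a hypothetical `pickColumn` helper, not the Spark API:

```scala
import java.util.UUID

// Hypothetical helper mirroring the logic above: reuse the configured name if present,
// otherwise generate a unique temporary column name so existing columns are not clobbered.
def pickColumn(configured: String, prefix: String): String =
  if (configured.isEmpty) prefix + "_" + UUID.randomUUID.toString else configured

val rawPredictionCol = pickColumn("", "rawPrediction")        // e.g. "rawPrediction_3f2a..."
val predictionCol = pickColumn("prediction", "prediction")    // user-set name is kept as-is
println(s"$rawPredictionCol, $predictionCol")
```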

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 0 additions & 21 deletions
@@ -394,27 +394,6 @@ class LinearSVCModel private[classification] (
   @Since("3.1.0")
   override def summary: LinearSVCTrainingSummary = super.summary
 
-  /**
-   * If the rawPrediction and prediction columns are set, this method returns the current model,
-   * otherwise it generates new columns for them and sets them as columns on a new copy of
-   * the current model
-   */
-  private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = {
-    val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
-      copy(ParamMap.empty)
-        .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
-        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
-    } else if ($(rawPredictionCol).isEmpty) {
-      copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
-        java.util.UUID.randomUUID.toString)
-    } else if ($(predictionCol).isEmpty) {
-      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
-    } else {
-      this
-    }
-    (model, model.getRawPredictionCol, model.getPredictionCol)
-  }
-
   /**
    * Evaluates the model on a test dataset.
    *

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 4 additions & 25 deletions
@@ -1158,27 +1158,6 @@ class LogisticRegressionModel private[spark] (
       s"(numClasses=${numClasses}), use summary instead.")
   }
 
-  /**
-   * If the probability and prediction columns are set, this method returns the current model,
-   * otherwise it generates new columns for them and sets them as columns on a new copy of
-   * the current model
-   */
-  private[classification] def findSummaryModel():
-    (LogisticRegressionModel, String, String) = {
-    val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
-      copy(ParamMap.empty)
-        .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
-        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
-    } else if ($(probabilityCol).isEmpty) {
-      copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
-    } else if ($(predictionCol).isEmpty) {
-      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
-    } else {
-      this
-    }
-    (model, model.getProbabilityCol, model.getPredictionCol)
-  }
-
   /**
    * Evaluates the model on a test dataset.
    *
@@ -1451,7 +1430,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
  *                    double.
  * @param labelCol field in "predictions" which gives the true label of each instance.
  * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
  * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
  */
 private class LogisticRegressionTrainingSummaryImpl(
@@ -1476,7 +1455,7 @@ private class LogisticRegressionTrainingSummaryImpl(
  *                    double.
  * @param labelCol field in "predictions" which gives the true label of each instance.
  * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
  */
 private class LogisticRegressionSummaryImpl(
     @transient override val predictions: DataFrame,
@@ -1497,7 +1476,7 @@ private class LogisticRegressionSummaryImpl(
  *                    double.
  * @param labelCol field in "predictions" which gives the true label of each instance.
  * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
  * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
  */
 private class BinaryLogisticRegressionTrainingSummaryImpl(
@@ -1522,7 +1501,7 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
  *                 each class as a double.
  * @param labelCol field in "predictions" which gives the true label of each instance.
  * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
  */
 private class BinaryLogisticRegressionSummaryImpl(
     predictions: DataFrame,

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 22 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -229,6 +230,27 @@ abstract class ProbabilisticClassificationModel[
       argMax
     }
   }
+
+  /**
+   * If the probability and prediction columns are set, this method returns the current model,
+   * otherwise it generates new columns for them and sets them as columns on a new copy of
+   * the current model
+   */
+  override private[classification] def findSummaryModel():
+    (ProbabilisticClassificationModel[FeaturesType, M], String, String) = {
+    val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
+      copy(ParamMap.empty)
+        .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+    } else if ($(probabilityCol).isEmpty) {
+      copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+    } else if ($(predictionCol).isEmpty) {
+      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+    } else {
+      this
+    }
+    (model, model.getProbabilityCol, model.getPredictionCol)
+  }
 }
 
 private[ml] object ProbabilisticClassificationModel {

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 179 additions & 2 deletions
@@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
     val numFeatures = trees.head.numFeatures
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
-    new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
+    createModel(dataset, trees, numFeatures, numClasses)
+  }
+
+  private def createModel(
+      dataset: Dataset[_],
+      trees: Array[DecisionTreeClassificationModel],
+      numFeatures: Int,
+      numClasses: Int): RandomForestClassificationModel = {
+    val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+    val rfSummary = if (numClasses <= 2) {
+      new BinaryRandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    } else {
+      new RandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    }
+    model.setSummary(Some(rfSummary))
   }
 
   @Since("1.4.1")
@@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
     @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
-  with MLWritable with Serializable {
+  with MLWritable with Serializable
+  with HasTrainingSummary[RandomForestClassificationTrainingSummary] {
 
   require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
 
@@ -228,6 +257,44 @@ class RandomForestClassificationModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
+  /**
+   * Gets summary of model on training set. An exception is thrown
+   * if `hasSummary` is false.
+   */
+  @Since("3.1.0")
+  override def summary: RandomForestClassificationTrainingSummary = super.summary
+
+  /**
+   * Gets summary of model on training set. An exception is thrown
+   * if `hasSummary` is false or it is a multiclass model.
+   */
+  @Since("3.1.0")
+  def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match {
+    case b: BinaryRandomForestClassificationTrainingSummary => b
+    case _ =>
+      throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
+        s"(numClasses=${numClasses}), use summary instead.")
+  }
+
+  /**
+   * Evaluates the model on a test dataset.
+   *
+   * @param dataset Test dataset to evaluate model on.
+   */
+  @Since("3.1.0")
+  def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    if (numClasses > 2) {
+      new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+        predictionColName, $(labelCol), weightColName)
+    } else {
+      new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+        probabilityColName, predictionColName, $(labelCol), weightColName)
    }
+  }
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     var outputSchema = super.transformSchema(schema)
@@ -388,3 +455,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
     new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
   }
 }
+
+/**
+ * Abstraction for multiclass RandomForestClassification results for a given model.
+ */
+sealed trait RandomForestClassificationSummary extends ClassificationSummary {
+  /**
+   * Convenient method for casting to BinaryRandomForestClassificationSummary.
+   * This method will throw an Exception if the summary is not a binary summary.
+   */
+  @Since("3.1.0")
+  def asBinary: BinaryRandomForestClassificationSummary = this match {
+    case b: BinaryRandomForestClassificationSummary => b
+    case _ =>
+      throw new RuntimeException("Cannot cast to a binary summary.")
+  }
+}
+
+/**
+ * Abstraction for multiclass RandomForestClassification training results.
+ */
+sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
+  with TrainingSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification results for a given model.
+ */
+sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification training results.
+ */
+sealed trait BinaryRandomForestClassificationTrainingSummary extends
+  BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class RandomForestClassificationTrainingSummaryImpl(
+    predictions: DataFrame,
+    predictionCol: String,
+    labelCol: String,
+    weightCol: String,
+    override val objectiveHistory: Array[Double])
+  extends RandomForestClassificationSummaryImpl(
+    predictions, predictionCol, labelCol, weightCol)
+  with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
 */
+private class RandomForestClassificationSummaryImpl(
+    @transient override val predictions: DataFrame,
+    override val predictionCol: String,
+    override val labelCol: String,
+    override val weightCol: String)
+  extends RandomForestClassificationSummary
+
+/**
+ * Binary RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryRandomForestClassificationTrainingSummaryImpl(
+    predictions: DataFrame,
+    scoreCol: String,
+    predictionCol: String,
+    labelCol: String,
+    weightCol: String,
+    override val objectiveHistory: Array[Double])
+  extends BinaryRandomForestClassificationSummaryImpl(
+    predictions, scoreCol, predictionCol, labelCol, weightCol)
+  with BinaryRandomForestClassificationTrainingSummary
+
+/**
+ * Binary RandomForestClassification for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the prediction of
+ *                 each class as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
 */
+private class BinaryRandomForestClassificationSummaryImpl(
+    predictions: DataFrame,
+    override val scoreCol: String,
+    predictionCol: String,
+    labelCol: String,
+    weightCol: String)
+  extends RandomForestClassificationSummaryImpl(
+    predictions, predictionCol, labelCol, weightCol)
+  with BinaryRandomForestClassificationSummary
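
To show how the new summary hierarchy above is meant to be consumed, here is a hedged sketch; it assumes the fitted binary `model` from the example after the commit message:

```scala
// Binary-only view: throws for a multiclass model, per binarySummary's contract above.
val binSummary = model.binarySummary
println(binSummary.areaUnderROC)   // scalar ROC AUC
binSummary.roc.show()              // DataFrame of (FPR, TPR) points
binSummary.pr.show()               // DataFrame of (recall, precision) points

// Multiclass-safe metrics are available on the plain summary for any number of classes.
println(model.summary.weightedFMeasure)
println(model.summary.weightedPrecision)
```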

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
       blorModel2.summary.asBinary.weightedPrecision relTol 1e-6)
     assert(blorModel.summary.asBinary.weightedRecall ~==
       blorModel2.summary.asBinary.weightedRecall relTol 1e-6)
-    assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~==
+    assert(blorModel.summary.asBinary.areaUnderROC ~==
       blorModel2.summary.asBinary.areaUnderROC relTol 1e-6)
 
     assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6)
