Skip to content

Commit 1e3e3d5

Browse files
committed
[SPARK-23631][ML][PySpark] Add summary to RandomForestClassificationModel
1 parent fcf9768 commit 1e3e3d5

File tree

7 files changed

+473
-11
lines changed

7 files changed

+473
-11
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ private[classification] trait ClassificationSummary extends Serializable {
4444
@Since("3.1.0")
4545
def labelCol: String
4646

47-
/** Field in "predictions" which gives the weight of each instance as a vector. */
47+
/** Field in "predictions" which gives the weight of each instance. */
4848
@Since("3.1.0")
4949
def weightCol: String
5050

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
14511451
* double.
14521452
* @param labelCol field in "predictions" which gives the true label of each instance.
14531453
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
1454-
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
1454+
* @param weightCol field in "predictions" which gives the weight of each instance.
14551455
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
14561456
*/
14571457
private class LogisticRegressionTrainingSummaryImpl(
@@ -1476,7 +1476,7 @@ private class LogisticRegressionTrainingSummaryImpl(
14761476
* double.
14771477
* @param labelCol field in "predictions" which gives the true label of each instance.
14781478
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
1479-
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
1479+
* @param weightCol field in "predictions" which gives the weight of each instance.
14801480
*/
14811481
private class LogisticRegressionSummaryImpl(
14821482
@transient override val predictions: DataFrame,
@@ -1497,7 +1497,7 @@ private class LogisticRegressionSummaryImpl(
14971497
* double.
14981498
* @param labelCol field in "predictions" which gives the true label of each instance.
14991499
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
1500-
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
1500+
* @param weightCol field in "predictions" which gives the weight of each instance.
15011501
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
15021502
*/
15031503
private class BinaryLogisticRegressionTrainingSummaryImpl(
@@ -1522,7 +1522,7 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
15221522
* each class as a double.
15231523
* @param labelCol field in "predictions" which gives the true label of each instance.
15241524
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
1525-
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
1525+
* @param weightCol field in "predictions" which gives the weight of each instance.
15261526
*/
15271527
private class BinaryLogisticRegressionSummaryImpl(
15281528
predictions: DataFrame,

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 200 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
166166
val numFeatures = trees.head.numFeatures
167167
instr.logNumClasses(numClasses)
168168
instr.logNumFeatures(numFeatures)
169-
new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
169+
createModel(dataset, trees, numFeatures, numClasses)
170+
}
171+
172+
private def createModel(
173+
dataset: Dataset[_],
174+
trees: Array[DecisionTreeClassificationModel],
175+
numFeatures: Int,
176+
numClasses: Int): RandomForestClassificationModel = {
177+
val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
178+
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
179+
180+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
181+
val rfSummary = if (numClasses <= 2) {
182+
new BinaryRandomForestClassificationTrainingSummaryImpl(
183+
summaryModel.transform(dataset),
184+
probabilityColName,
185+
predictionColName,
186+
$(labelCol),
187+
weightColName,
188+
Array(0.0))
189+
} else {
190+
new RandomForestClassificationTrainingSummaryImpl(
191+
summaryModel.transform(dataset),
192+
predictionColName,
193+
$(labelCol),
194+
weightColName,
195+
Array(0.0))
196+
}
197+
model.setSummary(Some(rfSummary))
170198
}
171199

172200
@Since("1.4.1")
@@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
204232
@Since("1.5.0") override val numClasses: Int)
205233
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
206234
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
207-
with MLWritable with Serializable {
235+
with MLWritable with Serializable
236+
with HasTrainingSummary[RandomForestClassificationTrainingSummary] {
208237

209238
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
210239

@@ -228,6 +257,65 @@ class RandomForestClassificationModel private[ml] (
228257
@Since("1.4.0")
229258
override def treeWeights: Array[Double] = _treeWeights
230259

260+
/**
261+
* Gets summary of model on training set. An exception is thrown
262+
* if `hasSummary` is false.
263+
*/
264+
@Since("3.1.0")
265+
override def summary: RandomForestClassificationTrainingSummary = super.summary
266+
267+
/**
268+
* Gets summary of model on training set. An exception is thrown
269+
* if `hasSummary` is false or it is a multiclass model.
270+
*/
271+
@Since("3.1.0")
272+
def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match {
273+
case b: BinaryRandomForestClassificationTrainingSummary => b
274+
case _ =>
275+
throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
276+
s"(numClasses=${numClasses}), use summary instead.")
277+
}
278+
279+
/**
280+
* If the probability and prediction columns are set, this method returns the current model,
281+
* otherwise it generates new columns for them and sets them as columns on a new copy of
282+
* the current model
283+
*/
284+
private[classification] def findSummaryModel():
285+
(RandomForestClassificationModel, String, String) = {
286+
val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
287+
copy(ParamMap.empty)
288+
.setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
289+
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
290+
} else if ($(probabilityCol).isEmpty) {
291+
copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
292+
} else if ($(predictionCol).isEmpty) {
293+
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
294+
} else {
295+
this
296+
}
297+
(model, model.getProbabilityCol, model.getPredictionCol)
298+
}
299+
300+
/**
301+
* Evaluates the model on a test dataset.
302+
*
303+
* @param dataset Test dataset to evaluate model on.
304+
*/
305+
@Since("3.1.0")
306+
def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = {
307+
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
308+
// Handle possible missing or invalid prediction columns
309+
val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
310+
if (numClasses > 2) {
311+
new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
312+
predictionColName, $(labelCol), weightColName)
313+
} else {
314+
new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
315+
probabilityColName, predictionColName, $(labelCol), weightColName)
316+
}
317+
}
318+
231319
@Since("1.4.0")
232320
override def transformSchema(schema: StructType): StructType = {
233321
var outputSchema = super.transformSchema(schema)
@@ -388,3 +476,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
388476
new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
389477
}
390478
}
479+
480+
/**
481+
* Abstraction for multiclass RandomForestClassification results for a given model.
482+
*/
483+
sealed trait RandomForestClassificationSummary extends ClassificationSummary {
484+
/**
485+
* Convenient method for casting to BinaryRandomForestClassificationSummary.
486+
* This method will throw an Exception if the summary is not a binary summary.
487+
*/
488+
@Since("3.1.0")
489+
def asBinary: BinaryRandomForestClassificationSummary = this match {
490+
case b: BinaryRandomForestClassificationSummary => b
491+
case _ =>
492+
throw new RuntimeException("Cannot cast to a binary summary.")
493+
}
494+
}
495+
496+
/**
497+
* Abstraction for multiclass RandomForestClassification training results.
498+
*/
499+
sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
500+
with TrainingSummary
501+
502+
/**
503+
* Abstraction for BinaryRandomForestClassification results for a given model.
504+
*/
505+
sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
506+
507+
/**
508+
* Abstraction for BinaryRandomForestClassification training results.
509+
*/
510+
sealed trait BinaryRandomForestClassificationTrainingSummary extends
511+
BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
512+
513+
/**
514+
* Multiclass RandomForestClassification training results.
515+
*
516+
* @param predictions dataframe output by the model's `transform` method.
517+
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
518+
* double.
519+
* @param labelCol field in "predictions" which gives the true label of each instance.
520+
* @param weightCol field in "predictions" which gives the weight of each instance.
521+
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
522+
*/
523+
private class RandomForestClassificationTrainingSummaryImpl(
524+
predictions: DataFrame,
525+
predictionCol: String,
526+
labelCol: String,
527+
weightCol: String,
528+
override val objectiveHistory: Array[Double])
529+
extends RandomForestClassificationSummaryImpl(
530+
predictions, predictionCol, labelCol, weightCol)
531+
with RandomForestClassificationTrainingSummary
532+
533+
/**
534+
* Multiclass RandomForestClassification results for a given model.
535+
*
536+
* @param predictions dataframe output by the model's `transform` method.
537+
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
538+
* double.
539+
* @param labelCol field in "predictions" which gives the true label of each instance.
540+
* @param weightCol field in "predictions" which gives the weight of each instance.
541+
*/
542+
private class RandomForestClassificationSummaryImpl(
543+
@transient override val predictions: DataFrame,
544+
override val predictionCol: String,
545+
override val labelCol: String,
546+
override val weightCol: String)
547+
extends RandomForestClassificationSummary
548+
549+
/**
550+
* Binary RandomForestClassification training results.
551+
*
552+
* @param predictions dataframe output by the model's `transform` method.
553+
* @param scoreCol field in "predictions" which gives the probability of each class as a vector.
554+
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
555+
* double.
556+
* @param labelCol field in "predictions" which gives the true label of each instance.
557+
* @param weightCol field in "predictions" which gives the weight of each instance.
558+
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
559+
*/
560+
private class BinaryRandomForestClassificationTrainingSummaryImpl(
561+
predictions: DataFrame,
562+
scoreCol: String,
563+
predictionCol: String,
564+
labelCol: String,
565+
weightCol: String,
566+
override val objectiveHistory: Array[Double])
567+
extends BinaryRandomForestClassificationSummaryImpl(
568+
predictions, scoreCol, predictionCol, labelCol, weightCol)
569+
with BinaryRandomForestClassificationTrainingSummary
570+
571+
/**
572+
* Binary RandomForestClassification for a given model.
573+
*
574+
* @param predictions dataframe output by the model's `transform` method.
575+
* @param scoreCol field in "predictions" which gives the prediction of
576+
* each class as a vector.
577+
* @param labelCol field in "predictions" which gives the true label of each instance.
578+
* @param weightCol field in "predictions" which gives the weight of each instance.
579+
*/
580+
private class BinaryRandomForestClassificationSummaryImpl(
581+
predictions: DataFrame,
582+
override val scoreCol: String,
583+
predictionCol: String,
584+
labelCol: String,
585+
weightCol: String)
586+
extends RandomForestClassificationSummaryImpl(
587+
predictions, predictionCol, labelCol, weightCol)
588+
with BinaryRandomForestClassificationSummary

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
342342
blorModel2.summary.asBinary.weightedPrecision relTol 1e-6)
343343
assert(blorModel.summary.asBinary.weightedRecall ~==
344344
blorModel2.summary.asBinary.weightedRecall relTol 1e-6)
345-
assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~==
345+
assert(blorModel.summary.asBinary.areaUnderROC ~==
346346
blorModel2.summary.asBinary.areaUnderROC relTol 1e-6)
347347

348348
assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6)

0 commit comments

Comments
 (0)