@@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
166166 val numFeatures = trees.head.numFeatures
167167 instr.logNumClasses(numClasses)
168168 instr.logNumFeatures(numFeatures)
169- new RandomForestClassificationModel (uid, trees, numFeatures, numClasses)
169+ createModel(dataset, trees, numFeatures, numClasses)
170+ }
171+
172+ private def createModel (
173+ dataset : Dataset [_],
174+ trees : Array [DecisionTreeClassificationModel ],
175+ numFeatures : Int ,
176+ numClasses : Int ): RandomForestClassificationModel = {
177+ val model = copyValues(new RandomForestClassificationModel (uid, trees, numFeatures, numClasses))
178+ val weightColName = if (! isDefined(weightCol)) " weightCol" else $(weightCol)
179+
180+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
181+ val rfSummary = if (numClasses <= 2 ) {
182+ new BinaryRandomForestClassificationTrainingSummaryImpl (
183+ summaryModel.transform(dataset),
184+ probabilityColName,
185+ predictionColName,
186+ $(labelCol),
187+ weightColName,
188+ Array (0.0 ))
189+ } else {
190+ new RandomForestClassificationTrainingSummaryImpl (
191+ summaryModel.transform(dataset),
192+ predictionColName,
193+ $(labelCol),
194+ weightColName,
195+ Array (0.0 ))
196+ }
197+ model.setSummary(Some (rfSummary))
170198 }
171199
172200 @ Since (" 1.4.1" )
@@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
204232 @ Since (" 1.5.0" ) override val numClasses : Int )
205233 extends ProbabilisticClassificationModel [Vector , RandomForestClassificationModel ]
206234 with RandomForestClassifierParams with TreeEnsembleModel [DecisionTreeClassificationModel ]
207- with MLWritable with Serializable {
235+ with MLWritable with Serializable
236+ with HasTrainingSummary [RandomForestClassificationTrainingSummary ] {
208237
209238 require(_trees.nonEmpty, " RandomForestClassificationModel requires at least 1 tree." )
210239
@@ -228,6 +257,65 @@ class RandomForestClassificationModel private[ml] (
228257 @ Since (" 1.4.0" )
229258 override def treeWeights : Array [Double ] = _treeWeights
230259
260+ /**
261+ * Gets summary of model on training set. An exception is thrown
262+ * if `hasSummary` is false.
263+ */
264+ @ Since (" 3.1.0" )
265+ override def summary : RandomForestClassificationTrainingSummary = super .summary
266+
267+ /**
268+ * Gets summary of model on training set. An exception is thrown
269+ * if `hasSummary` is false or it is a multiclass model.
270+ */
271+ @ Since (" 3.1.0" )
272+ def binarySummary : BinaryRandomForestClassificationTrainingSummary = summary match {
273+ case b : BinaryRandomForestClassificationTrainingSummary => b
274+ case _ =>
275+ throw new RuntimeException (" Cannot create a binary summary for a non-binary model" +
276+ s " (numClasses= ${numClasses}), use summary instead. " )
277+ }
278+
279+ /**
280+ * If the probability and prediction columns are set, this method returns the current model,
281+ * otherwise it generates new columns for them and sets them as columns on a new copy of
282+ * the current model
283+ */
284+ private [classification] def findSummaryModel ():
285+ (RandomForestClassificationModel , String , String ) = {
286+ val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
287+ copy(ParamMap .empty)
288+ .setProbabilityCol(" probability_" + java.util.UUID .randomUUID.toString)
289+ .setPredictionCol(" prediction_" + java.util.UUID .randomUUID.toString)
290+ } else if ($(probabilityCol).isEmpty) {
291+ copy(ParamMap .empty).setProbabilityCol(" probability_" + java.util.UUID .randomUUID.toString)
292+ } else if ($(predictionCol).isEmpty) {
293+ copy(ParamMap .empty).setPredictionCol(" prediction_" + java.util.UUID .randomUUID.toString)
294+ } else {
295+ this
296+ }
297+ (model, model.getProbabilityCol, model.getPredictionCol)
298+ }
299+
300+ /**
301+ * Evaluates the model on a test dataset.
302+ *
303+ * @param dataset Test dataset to evaluate model on.
304+ */
305+ @ Since (" 3.1.0" )
306+ def evaluate (dataset : Dataset [_]): RandomForestClassificationSummary = {
307+ val weightColName = if (! isDefined(weightCol)) " weightCol" else $(weightCol)
308+ // Handle possible missing or invalid prediction columns
309+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
310+ if (numClasses > 2 ) {
311+ new RandomForestClassificationSummaryImpl (summaryModel.transform(dataset),
312+ predictionColName, $(labelCol), weightColName)
313+ } else {
314+ new BinaryRandomForestClassificationSummaryImpl (summaryModel.transform(dataset),
315+ probabilityColName, predictionColName, $(labelCol), weightColName)
316+ }
317+ }
318+
231319 @ Since (" 1.4.0" )
232320 override def transformSchema (schema : StructType ): StructType = {
233321 var outputSchema = super .transformSchema(schema)
@@ -388,3 +476,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
388476 new RandomForestClassificationModel (uid, newTrees, numFeatures, numClasses)
389477 }
390478}
479+
480+ /**
481+ * Abstraction for multiclass RandomForestClassification results for a given model.
482+ */
483+ sealed trait RandomForestClassificationSummary extends ClassificationSummary {
484+ /**
485+ * Convenient method for casting to BinaryRandomForestClassificationSummary.
486+ * This method will throw an Exception if the summary is not a binary summary.
487+ */
488+ @ Since (" 3.1.0" )
489+ def asBinary : BinaryRandomForestClassificationSummary = this match {
490+ case b : BinaryRandomForestClassificationSummary => b
491+ case _ =>
492+ throw new RuntimeException (" Cannot cast to a binary summary." )
493+ }
494+ }
495+
496+ /**
497+ * Abstraction for multiclass RandomForestClassification training results.
498+ */
499+ sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
500+ with TrainingSummary
501+
502+ /**
503+ * Abstraction for BinaryRandomForestClassification results for a given model.
504+ */
505+ sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
506+
507+ /**
508+ * Abstraction for BinaryRandomForestClassification training results.
509+ */
510+ sealed trait BinaryRandomForestClassificationTrainingSummary extends
511+ BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
512+
513+ /**
514+ * Multiclass RandomForestClassification training results.
515+ *
516+ * @param predictions dataframe output by the model's `transform` method.
517+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
518+ * double.
519+ * @param labelCol field in "predictions" which gives the true label of each instance.
520+ * @param weightCol field in "predictions" which gives the weight of each instance.
521+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
522+ */
523+ private class RandomForestClassificationTrainingSummaryImpl (
524+ predictions : DataFrame ,
525+ predictionCol : String ,
526+ labelCol : String ,
527+ weightCol : String ,
528+ override val objectiveHistory : Array [Double ])
529+ extends RandomForestClassificationSummaryImpl (
530+ predictions, predictionCol, labelCol, weightCol)
531+ with RandomForestClassificationTrainingSummary
532+
533+ /**
534+ * Multiclass RandomForestClassification results for a given model.
535+ *
536+ * @param predictions dataframe output by the model's `transform` method.
537+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
538+ * double.
539+ * @param labelCol field in "predictions" which gives the true label of each instance.
540+ * @param weightCol field in "predictions" which gives the weight of each instance.
541+ */
542+ private class RandomForestClassificationSummaryImpl (
543+ @ transient override val predictions : DataFrame ,
544+ override val predictionCol : String ,
545+ override val labelCol : String ,
546+ override val weightCol : String )
547+ extends RandomForestClassificationSummary
548+
549+ /**
550+ * Binary RandomForestClassification training results.
551+ *
552+ * @param predictions dataframe output by the model's `transform` method.
553+ * @param scoreCol field in "predictions" which gives the probability of each class as a vector.
554+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
555+ * double.
556+ * @param labelCol field in "predictions" which gives the true label of each instance.
557+ * @param weightCol field in "predictions" which gives the weight of each instance.
558+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
559+ */
560+ private class BinaryRandomForestClassificationTrainingSummaryImpl (
561+ predictions : DataFrame ,
562+ scoreCol : String ,
563+ predictionCol : String ,
564+ labelCol : String ,
565+ weightCol : String ,
566+ override val objectiveHistory : Array [Double ])
567+ extends BinaryRandomForestClassificationSummaryImpl (
568+ predictions, scoreCol, predictionCol, labelCol, weightCol)
569+ with BinaryRandomForestClassificationTrainingSummary
570+
571+ /**
572+ * Binary RandomForestClassification for a given model.
573+ *
574+ * @param predictions dataframe output by the model's `transform` method.
575+ * @param scoreCol field in "predictions" which gives the prediction of
576+ * each class as a vector.
577+ * @param labelCol field in "predictions" which gives the true label of each instance.
578+ * @param weightCol field in "predictions" which gives the weight of each instance.
579+ */
580+ private class BinaryRandomForestClassificationSummaryImpl (
581+ predictions : DataFrame ,
582+ override val scoreCol : String ,
583+ predictionCol : String ,
584+ labelCol : String ,
585+ weightCol : String )
586+ extends RandomForestClassificationSummaryImpl (
587+ predictions, predictionCol, labelCol, weightCol)
588+ with BinaryRandomForestClassificationSummary
0 commit comments