@@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
166166 val numFeatures = trees.head.numFeatures
167167 instr.logNumClasses(numClasses)
168168 instr.logNumFeatures(numFeatures)
169- new RandomForestClassificationModel (uid, trees, numFeatures, numClasses)
169+ createModel(dataset, trees, numFeatures, numClasses)
170+ }
171+
172+ private def createModel (
173+ dataset : Dataset [_],
174+ trees : Array [DecisionTreeClassificationModel ],
175+ numFeatures : Int ,
176+ numClasses : Int ): RandomForestClassificationModel = {
177+ val model = copyValues(new RandomForestClassificationModel (uid, trees, numFeatures, numClasses))
178+ val weightColName = if (! isDefined(weightCol)) " weightCol" else $(weightCol)
179+
180+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
181+ val rfSummary = if (numClasses <= 2 ) {
182+ new BinaryRandomForestClassificationTrainingSummaryImpl (
183+ summaryModel.transform(dataset),
184+ probabilityColName,
185+ predictionColName,
186+ $(labelCol),
187+ weightColName,
188+ Array (0.0 ))
189+ } else {
190+ new RandomForestClassificationTrainingSummaryImpl (
191+ summaryModel.transform(dataset),
192+ predictionColName,
193+ $(labelCol),
194+ weightColName,
195+ Array (0.0 ))
196+ }
197+ model.setSummary(Some (rfSummary))
170198 }
171199
172200 @ Since (" 1.4.1" )
@@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
204232 @ Since (" 1.5.0" ) override val numClasses : Int )
205233 extends ProbabilisticClassificationModel [Vector , RandomForestClassificationModel ]
206234 with RandomForestClassifierParams with TreeEnsembleModel [DecisionTreeClassificationModel ]
207- with MLWritable with Serializable {
235+ with MLWritable with Serializable
236+ with HasTrainingSummary [RandomForestClassificationTrainingSummary ] {
208237
209238 require(_trees.nonEmpty, " RandomForestClassificationModel requires at least 1 tree." )
210239
@@ -228,6 +257,44 @@ class RandomForestClassificationModel private[ml] (
228257 @ Since (" 1.4.0" )
229258 override def treeWeights : Array [Double ] = _treeWeights
230259
260+ /**
261+ * Gets summary of model on training set. An exception is thrown
262+ * if `hasSummary` is false.
263+ */
264+ @ Since (" 3.1.0" )
265+ override def summary : RandomForestClassificationTrainingSummary = super .summary
266+
267+ /**
268+ * Gets summary of model on training set. An exception is thrown
269+ * if `hasSummary` is false or it is a multiclass model.
270+ */
271+ @ Since (" 3.1.0" )
272+ def binarySummary : BinaryRandomForestClassificationTrainingSummary = summary match {
273+ case b : BinaryRandomForestClassificationTrainingSummary => b
274+ case _ =>
275+ throw new RuntimeException (" Cannot create a binary summary for a non-binary model" +
276+ s " (numClasses= ${numClasses}), use summary instead. " )
277+ }
278+
279+ /**
280+ * Evaluates the model on a test dataset.
281+ *
282+ * @param dataset Test dataset to evaluate model on.
283+ */
284+ @ Since (" 3.1.0" )
285+ def evaluate (dataset : Dataset [_]): RandomForestClassificationSummary = {
286+ val weightColName = if (! isDefined(weightCol)) " weightCol" else $(weightCol)
287+ // Handle possible missing or invalid prediction columns
288+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
289+ if (numClasses > 2 ) {
290+ new RandomForestClassificationSummaryImpl (summaryModel.transform(dataset),
291+ predictionColName, $(labelCol), weightColName)
292+ } else {
293+ new BinaryRandomForestClassificationSummaryImpl (summaryModel.transform(dataset),
294+ probabilityColName, predictionColName, $(labelCol), weightColName)
295+ }
296+ }
297+
231298 @ Since (" 1.4.0" )
232299 override def transformSchema (schema : StructType ): StructType = {
233300 var outputSchema = super .transformSchema(schema)
@@ -388,3 +455,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
388455 new RandomForestClassificationModel (uid, newTrees, numFeatures, numClasses)
389456 }
390457}
458+
459+ /**
460+ * Abstraction for multiclass RandomForestClassification results for a given model.
461+ */
462+ sealed trait RandomForestClassificationSummary extends ClassificationSummary {
463+ /**
464+ * Convenient method for casting to BinaryRandomForestClassificationSummary.
465+ * This method will throw an Exception if the summary is not a binary summary.
466+ */
467+ @ Since (" 3.1.0" )
468+ def asBinary : BinaryRandomForestClassificationSummary = this match {
469+ case b : BinaryRandomForestClassificationSummary => b
470+ case _ =>
471+ throw new RuntimeException (" Cannot cast to a binary summary." )
472+ }
473+ }
474+
475+ /**
476+ * Abstraction for multiclass RandomForestClassification training results.
477+ */
478+ sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
479+ with TrainingSummary
480+
481+ /**
482+ * Abstraction for BinaryRandomForestClassification results for a given model.
483+ */
484+ sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
485+
486+ /**
487+ * Abstraction for BinaryRandomForestClassification training results.
488+ */
489+ sealed trait BinaryRandomForestClassificationTrainingSummary extends
490+ BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
491+
492+ /**
493+ * Multiclass RandomForestClassification training results.
494+ *
495+ * @param predictions dataframe output by the model's `transform` method.
496+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
497+ * double.
498+ * @param labelCol field in "predictions" which gives the true label of each instance.
499+ * @param weightCol field in "predictions" which gives the weight of each instance.
500+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
501+ */
502+ private class RandomForestClassificationTrainingSummaryImpl (
503+ predictions : DataFrame ,
504+ predictionCol : String ,
505+ labelCol : String ,
506+ weightCol : String ,
507+ override val objectiveHistory : Array [Double ])
508+ extends RandomForestClassificationSummaryImpl (
509+ predictions, predictionCol, labelCol, weightCol)
510+ with RandomForestClassificationTrainingSummary
511+
512+ /**
513+ * Multiclass RandomForestClassification results for a given model.
514+ *
515+ * @param predictions dataframe output by the model's `transform` method.
516+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
517+ * double.
518+ * @param labelCol field in "predictions" which gives the true label of each instance.
519+ * @param weightCol field in "predictions" which gives the weight of each instance.
520+ */
521+ private class RandomForestClassificationSummaryImpl (
522+ @ transient override val predictions : DataFrame ,
523+ override val predictionCol : String ,
524+ override val labelCol : String ,
525+ override val weightCol : String )
526+ extends RandomForestClassificationSummary
527+
528+ /**
529+ * Binary RandomForestClassification training results.
530+ *
531+ * @param predictions dataframe output by the model's `transform` method.
532+ * @param scoreCol field in "predictions" which gives the probability of each class as a vector.
533+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
534+ * double.
535+ * @param labelCol field in "predictions" which gives the true label of each instance.
536+ * @param weightCol field in "predictions" which gives the weight of each instance.
537+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
538+ */
539+ private class BinaryRandomForestClassificationTrainingSummaryImpl (
540+ predictions : DataFrame ,
541+ scoreCol : String ,
542+ predictionCol : String ,
543+ labelCol : String ,
544+ weightCol : String ,
545+ override val objectiveHistory : Array [Double ])
546+ extends BinaryRandomForestClassificationSummaryImpl (
547+ predictions, scoreCol, predictionCol, labelCol, weightCol)
548+ with BinaryRandomForestClassificationTrainingSummary
549+
550+ /**
551+ * Binary RandomForestClassification for a given model.
552+ *
553+ * @param predictions dataframe output by the model's `transform` method.
554+ * @param scoreCol field in "predictions" which gives the prediction of
555+ * each class as a vector.
556+ * @param labelCol field in "predictions" which gives the true label of each instance.
557+ * @param weightCol field in "predictions" which gives the weight of each instance.
558+ */
559+ private class BinaryRandomForestClassificationSummaryImpl (
560+ predictions : DataFrame ,
561+ override val scoreCol : String ,
562+ predictionCol : String ,
563+ labelCol : String ,
564+ weightCol : String )
565+ extends RandomForestClassificationSummaryImpl (
566+ predictions, predictionCol, labelCol, weightCol)
567+ with BinaryRandomForestClassificationSummary
0 commit comments