Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[classification] trait ClassificationSummary extends Serializable {
@Since("3.1.0")
def labelCol: String

/** Field in "predictions" which gives the weight of each instance as a vector. */
/** Field in "predictions" which gives the weight of each instance. */
@Since("3.1.0")
def weightCol: String

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -269,4 +270,26 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return predicted label
*/
protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax

/**
* If the rawPrediction and prediction columns are set, this method returns the current model,
* otherwise it generates new columns for them and sets them as columns on a new copy of
* the current model
*/
private[classification] def findSummaryModel():
(ClassificationModel[FeaturesType, M], String, String) = {
val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(rawPredictionCol).isEmpty) {
copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getRawPredictionCol, model.getPredictionCol)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,27 +394,6 @@ class LinearSVCModel private[classification] (
@Since("3.1.0")
override def summary: LinearSVCTrainingSummary = super.summary

/**
* If the rawPrediction and prediction columns are set, this method returns the current model,
* otherwise it generates new columns for them and sets them as columns on a new copy of
* the current model
*/
private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = {
val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(rawPredictionCol).isEmpty) {
copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getRawPredictionCol, model.getPredictionCol)
}

/**
* Evaluates the model on a test dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1158,27 +1158,6 @@ class LogisticRegressionModel private[spark] (
s"(numClasses=${numClasses}), use summary instead.")
}

/**
* If the probability and prediction columns are set, this method returns the current model,
* otherwise it generates new columns for them and sets them as columns on a new copy of
* the current model
*/
private[classification] def findSummaryModel():
(LogisticRegressionModel, String, String) = {
val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(probabilityCol).isEmpty) {
copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getProbabilityCol, model.getPredictionCol)
}

/**
* Evaluates the model on a test dataset.
*
Expand Down Expand Up @@ -1451,7 +1430,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class LogisticRegressionTrainingSummaryImpl(
Expand All @@ -1476,7 +1455,7 @@ private class LogisticRegressionTrainingSummaryImpl(
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class LogisticRegressionSummaryImpl(
@transient override val predictions: DataFrame,
Expand All @@ -1497,7 +1476,7 @@ private class LogisticRegressionSummaryImpl(
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class BinaryLogisticRegressionTrainingSummaryImpl(
Expand All @@ -1522,7 +1501,7 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
* each class as a double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance as a vector.
* @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class BinaryLogisticRegressionSummaryImpl(
predictions: DataFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
Expand Down Expand Up @@ -229,6 +230,27 @@ abstract class ProbabilisticClassificationModel[
argMax
}
}

/**
*If the probability and prediction columns are set, this method returns the current model,
* otherwise it generates new columns for them and sets them as columns on a new copy of
* the current model
*/
override private[classification] def findSummaryModel():
(ProbabilisticClassificationModel[FeaturesType, M], String, String) = {
val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(probabilityCol).isEmpty) {
copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getProbabilityCol, model.getPredictionCol)
}
}

private[ml] object ProbabilisticClassificationModel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = trees.head.numFeatures
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
createModel(dataset, trees, numFeatures, numClasses)
}

private def createModel(
dataset: Dataset[_],
trees: Array[DecisionTreeClassificationModel],
numFeatures: Int,
numClasses: Int): RandomForestClassificationModel = {
val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val rfSummary = if (numClasses <= 2) {
new BinaryRandomForestClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
predictionColName,
$(labelCol),
weightColName,
Array(0.0))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for non iterative algorithm, set objectiveHistory to Array(0.0).

} else {
new RandomForestClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
weightColName,
Array(0.0))
}
model.setSummary(Some(rfSummary))
}

@Since("1.4.1")
Expand Down Expand Up @@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {
with MLWritable with Serializable
with HasTrainingSummary[RandomForestClassificationTrainingSummary] {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")

Expand All @@ -228,6 +257,44 @@ class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

/**
* Gets summary of model on training set. An exception is thrown
* if `hasSummary` is false.
*/
@Since("3.1.0")
override def summary: RandomForestClassificationTrainingSummary = super.summary

/**
* Gets summary of model on training set. An exception is thrown
* if `hasSummary` is false or it is a multiclass model.
*/
@Since("3.1.0")
def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match {
case b: BinaryRandomForestClassificationTrainingSummary => b
case _ =>
throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
s"(numClasses=${numClasses}), use summary instead.")
}

/**
* Evaluates the model on a test dataset.
*
* @param dataset Test dataset to evaluate model on.
*/
@Since("3.1.0")
def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = {
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
// Handle possible missing or invalid prediction columns
val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
if (numClasses > 2) {
new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
predictionColName, $(labelCol), weightColName)
} else {
new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
probabilityColName, predictionColName, $(labelCol), weightColName)
}
}

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
Expand Down Expand Up @@ -388,3 +455,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
}
}

/**
* Abstraction for multiclass RandomForestClassification results for a given model.
*/
sealed trait RandomForestClassificationSummary extends ClassificationSummary {
/**
* Convenient method for casting to BinaryRandomForestClassificationSummary.
* This method will throw an Exception if the summary is not a binary summary.
*/
@Since("3.1.0")
def asBinary: BinaryRandomForestClassificationSummary = this match {
case b: BinaryRandomForestClassificationSummary => b
case _ =>
throw new RuntimeException("Cannot cast to a binary summary.")
}
}

/**
* Abstraction for multiclass RandomForestClassification training results.
*/
sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
with TrainingSummary

/**
* Abstraction for BinaryRandomForestClassification results for a given model.
*/
sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary

/**
* Abstraction for BinaryRandomForestClassification training results.
*/
sealed trait BinaryRandomForestClassificationTrainingSummary extends
BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary

/**
* Multiclass RandomForestClassification training results.
*
* @param predictions dataframe output by the model's `transform` method.
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class RandomForestClassificationTrainingSummaryImpl(
predictions: DataFrame,
predictionCol: String,
labelCol: String,
weightCol: String,
override val objectiveHistory: Array[Double])
extends RandomForestClassificationSummaryImpl(
predictions, predictionCol, labelCol, weightCol)
with RandomForestClassificationTrainingSummary

/**
* Multiclass RandomForestClassification results for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class RandomForestClassificationSummaryImpl(
@transient override val predictions: DataFrame,
override val predictionCol: String,
override val labelCol: String,
override val weightCol: String)
extends RandomForestClassificationSummary

/**
* Binary RandomForestClassification training results.
*
* @param predictions dataframe output by the model's `transform` method.
* @param scoreCol field in "predictions" which gives the probability of each class as a vector.
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class BinaryRandomForestClassificationTrainingSummaryImpl(
predictions: DataFrame,
scoreCol: String,
predictionCol: String,
labelCol: String,
weightCol: String,
override val objectiveHistory: Array[Double])
extends BinaryRandomForestClassificationSummaryImpl(
predictions, scoreCol, predictionCol, labelCol, weightCol)
with BinaryRandomForestClassificationTrainingSummary

/**
* Binary RandomForestClassification for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
* @param scoreCol field in "predictions" which gives the prediction of
* each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class BinaryRandomForestClassificationSummaryImpl(
predictions: DataFrame,
override val scoreCol: String,
predictionCol: String,
labelCol: String,
weightCol: String)
extends RandomForestClassificationSummaryImpl(
predictions, predictionCol, labelCol, weightCol)
with BinaryRandomForestClassificationSummary
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
blorModel2.summary.asBinary.weightedPrecision relTol 1e-6)
assert(blorModel.summary.asBinary.weightedRecall ~==
blorModel2.summary.asBinary.weightedRecall relTol 1e-6)
assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~==
assert(blorModel.summary.asBinary.areaUnderROC ~==
blorModel2.summary.asBinary.areaUnderROC relTol 1e-6)

assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6)
Expand Down
Loading