From f1993fc65e81814fe4cc8db9164e81e659adbf3f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 23 Jun 2020 15:47:40 -0700 Subject: [PATCH 1/3] [SPARK-23631][ML][PySpark] Add summary to RandomForestClassificationModel --- .../ClassificationSummary.scala | 2 +- .../classification/LogisticRegression.scala | 8 +- .../RandomForestClassifier.scala | 202 +++++++++++++++++- .../LogisticRegressionSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 110 ++++++++++ python/pyspark/ml/classification.py | 79 ++++++- .../pyspark/ml/tests/test_training_summary.py | 79 ++++++- 7 files changed, 472 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala index e9ea38161d3c..9f3428db484c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala @@ -44,7 +44,7 @@ private[classification] trait ClassificationSummary extends Serializable { @Since("3.1.0") def labelCol: String - /** Field in "predictions" which gives the weight of each instance as a vector. */ + /** Field in "predictions" which gives the weight of each instance. */ @Since("3.1.0") def weightCol: String diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 20d619334f7b..bbf4dd805068 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1451,7 +1451,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre * double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. - * @param weightCol field in "predictions" which gives the weight of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private class LogisticRegressionTrainingSummaryImpl( @@ -1476,7 +1476,7 @@ private class LogisticRegressionTrainingSummaryImpl( * double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. - * @param weightCol field in "predictions" which gives the weight of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance. */ private class LogisticRegressionSummaryImpl( @transient override val predictions: DataFrame, @@ -1497,7 +1497,7 @@ private class LogisticRegressionSummaryImpl( * double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. - * @param weightCol field in "predictions" which gives the weight of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private class BinaryLogisticRegressionTrainingSummaryImpl( @@ -1522,7 +1522,7 @@ private class BinaryLogisticRegressionTrainingSummaryImpl( * each class as a double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. - * @param weightCol field in "predictions" which gives the weight of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance. */ private class BinaryLogisticRegressionSummaryImpl( predictions: DataFrame, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a316e472d967..af5401f9b599 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") ( val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) - new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) + createModel(dataset, trees, numFeatures, numClasses) + } + + private def createModel( + dataset: Dataset[_], + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int): RandomForestClassificationModel = { + val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)) + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val rfSummary = if (numClasses <= 2) { + new BinaryRandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } else { + new RandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } + model.setSummary(Some(rfSummary)) } @Since("1.4.1") @@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] ( @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel] - with MLWritable with Serializable { + with MLWritable with Serializable + with HasTrainingSummary[RandomForestClassificationTrainingSummary] { require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") @@ -228,6 +257,65 @@ class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights + /** + * Gets summary of model on training set. An exception is thrown + * if `hasSummary` is false. + */ + @Since("3.1.0") + override def summary: RandomForestClassificationTrainingSummary = super.summary + + /** + * Gets summary of model on training set. An exception is thrown + * if `hasSummary` is false or it is a multiclass model. + */ + @Since("3.1.0") + def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match { + case b: BinaryRandomForestClassificationTrainingSummary => b + case _ => + throw new RuntimeException("Cannot create a binary summary for a non-binary model" + + s"(numClasses=${numClasses}), use summary instead.") + } + + /** + * If the probability and prediction columns are set, this method returns the current model, + * otherwise it generates new columns for them and sets them as columns on a new copy of + * the current model + */ + private[classification] def findSummaryModel(): + (RandomForestClassificationModel, String, String) = { + val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) { + copy(ParamMap.empty) + .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else if ($(probabilityCol).isEmpty) { + copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + } else if ($(predictionCol).isEmpty) { + copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else { + this + } + (model, model.getProbabilityCol, model.getPredictionCol) + } + + /** + * Evaluates the model on a test dataset. + * + * @param dataset Test dataset to evaluate model on. + */ + @Since("3.1.0") + def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName, predictionColName) = findSummaryModel() + if (numClasses > 2) { + new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset), + predictionColName, $(labelCol), weightColName) + } else { + new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset), + probabilityColName, predictionColName, $(labelCol), weightColName) + } + } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { var outputSchema = super.transformSchema(schema) @@ -388,3 +476,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses) } } + +/** + * Abstraction for multiclass RandomForestClassification results for a given model. + */ +sealed trait RandomForestClassificationSummary extends ClassificationSummary { + /** + * Convenient method for casting to BinaryRandomForestClassificationSummary. + * This method will throw an Exception if the summary is not a binary summary. + */ + @Since("3.1.0") + def asBinary: BinaryRandomForestClassificationSummary = this match { + case b: BinaryRandomForestClassificationSummary => b + case _ => + throw new RuntimeException("Cannot cast to a binary summary.") + } +} + +/** + * Abstraction for multiclass RandomForestClassification training results. + */ +sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary + with TrainingSummary + +/** + * Abstraction for BinaryRandomForestClassification results for a given model. + */ +sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary + +/** + * Abstraction for BinaryRandomForestClassification training results. + */ +sealed trait BinaryRandomForestClassificationTrainingSummary extends + BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary + +/** + * Multiclass RandomForestClassification training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class RandomForestClassificationTrainingSummaryImpl( + predictions: DataFrame, + predictionCol: String, + labelCol: String, + weightCol: String, + override val objectiveHistory: Array[Double]) + extends RandomForestClassificationSummaryImpl( + predictions, predictionCol, labelCol, weightCol) + with RandomForestClassificationTrainingSummary + +/** + * Multiclass RandomForestClassification results for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + */ +private class RandomForestClassificationSummaryImpl( + @transient override val predictions: DataFrame, + override val predictionCol: String, + override val labelCol: String, + override val weightCol: String) + extends RandomForestClassificationSummary + +/** + * Binary RandomForestClassification training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the probability of each class as a vector. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class BinaryRandomForestClassificationTrainingSummaryImpl( + predictions: DataFrame, + scoreCol: String, + predictionCol: String, + labelCol: String, + weightCol: String, + override val objectiveHistory: Array[Double]) + extends BinaryRandomForestClassificationSummaryImpl( + predictions, scoreCol, predictionCol, labelCol, weightCol) + with BinaryRandomForestClassificationTrainingSummary + +/** + * Binary RandomForestClassification for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the prediction of + * each class as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + */ +private class BinaryRandomForestClassificationSummaryImpl( + predictions: DataFrame, + override val scoreCol: String, + predictionCol: String, + labelCol: String, + weightCol: String) + extends RandomForestClassificationSummaryImpl( + predictions, predictionCol, labelCol, weightCol) + with BinaryRandomForestClassificationSummary diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index ecee531c88a8..56eadff6df07 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -342,7 +342,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { blorModel2.summary.asBinary.weightedPrecision relTol 1e-6) assert(blorModel.summary.asBinary.weightedRecall ~== blorModel2.summary.asBinary.weightedRecall relTol 1e-6) - assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~== + assert(blorModel.summary.asBinary.areaUnderROC ~== blorModel2.summary.asBinary.areaUnderROC relTol 1e-6) assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index e30e93ad4628..645a436fa0ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ /** * Test suite for [[RandomForestClassifier]]. @@ -296,6 +297,115 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { } } + test("summary for binary and multiclass") { + val arr = new Array[LabeledPoint](300) + for (i <- 0 until 300) { + if (i < 100) { + arr(i) = new LabeledPoint(0.0, Vectors.dense(2.0, 2.0)) + } else if (i < 200) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0)) + } + } + val rdd = sc.parallelize(arr) + val multinomialDataset = spark.createDataFrame(rdd) + + val rf = new RandomForestClassifier() + + val brfModel = rf.fit(binaryDataset) + assert(brfModel.summary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary]) + assert(brfModel.summary.asBinary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary]) + assert(brfModel.binarySummary.isInstanceOf[RandomForestClassificationTrainingSummary]) + assert(brfModel.summary.totalIterations === 0) + assert(brfModel.binarySummary.totalIterations === 0) + + val mrfModel = rf.fit(multinomialDataset) + assert(mrfModel.summary.isInstanceOf[RandomForestClassificationTrainingSummary]) + withClue("cannot get binary summary for multiclass model") { + intercept[RuntimeException] { + mrfModel.binarySummary + } + } + withClue("cannot cast summary to binary summary multiclass model") { + intercept[RuntimeException] { + mrfModel.summary.asBinary + } + } + assert(mrfModel.summary.totalIterations === 0) + + val brfSummary = brfModel.evaluate(binaryDataset) + val mrfSummary = mrfModel.evaluate(multinomialDataset) + assert(brfSummary.isInstanceOf[BinaryRandomForestClassificationSummary]) + assert(mrfSummary.isInstanceOf[RandomForestClassificationSummary]) + + assert(brfSummary.accuracy === brfModel.summary.accuracy) + assert(brfSummary.weightedPrecision === brfModel.summary.weightedPrecision) + assert(brfSummary.weightedRecall === brfModel.summary.weightedRecall) + assert(brfSummary.asBinary.areaUnderROC ~== brfModel.summary.asBinary.areaUnderROC relTol 1e-6) + + // verify instance weight works + val rf2 = new RandomForestClassifier() + .setWeightCol("weight") + + val binaryDatasetWithWeight = + binaryDataset.select(col("label"), col("features"), lit(2.5).as("weight")) + + val multinomialDatasetWithWeight = + multinomialDataset.select(col("label"), col("features"), lit(10.0).as("weight")) + + val brfModel2 = rf2.fit(binaryDatasetWithWeight) + assert(brfModel2.summary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary]) + assert(brfModel2.summary.asBinary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary]) + assert(brfModel2.binarySummary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary]) + + val mrfModel2 = rf2.fit(multinomialDatasetWithWeight) + assert(mrfModel2.summary.isInstanceOf[RandomForestClassificationTrainingSummary]) + withClue("cannot get binary summary for multiclass model") { + intercept[RuntimeException] { + mrfModel2.binarySummary + } + } + withClue("cannot cast summary to binary summary multiclass model") { + intercept[RuntimeException] { + mrfModel2.summary.asBinary + } + } + + val brfSummary2 = brfModel2.evaluate(binaryDatasetWithWeight) + val mrfSummary2 = mrfModel2.evaluate(multinomialDatasetWithWeight) + assert(brfSummary2.isInstanceOf[BinaryRandomForestClassificationSummary]) + assert(mrfSummary2.isInstanceOf[RandomForestClassificationSummary]) + + assert(brfSummary2.accuracy === brfModel2.summary.accuracy) + assert(brfSummary2.weightedPrecision === brfModel2.summary.weightedPrecision) + assert(brfSummary2.weightedRecall === brfModel2.summary.weightedRecall) + assert(brfSummary2.asBinary.areaUnderROC ~== + brfModel2.summary.asBinary.areaUnderROC relTol 1e-6) + + assert(brfSummary.accuracy ~== brfSummary2.accuracy relTol 1e-6) + assert(brfSummary.weightedPrecision ~== brfSummary2.weightedPrecision relTol 1e-6) + assert(brfSummary.weightedRecall ~== brfSummary2.weightedRecall relTol 1e-6) + assert(brfSummary.asBinary.areaUnderROC ~== brfSummary2.asBinary.areaUnderROC relTol 1e-6) + + assert(brfModel.summary.asBinary.accuracy ~== + brfModel2.summary.asBinary.accuracy relTol 1e-6) + assert(brfModel.summary.asBinary.weightedPrecision ~== + brfModel2.summary.asBinary.weightedPrecision relTol 1e-6) + assert(brfModel.summary.asBinary.weightedRecall ~== + brfModel2.summary.asBinary.weightedRecall relTol 1e-6) + assert(brfModel.summary.asBinary.areaUnderROC ~== + brfModel2.summary.asBinary.areaUnderROC relTol 1e-6) + + assert(mrfSummary.accuracy ~== mrfSummary2.accuracy relTol 1e-6) + assert(mrfSummary.weightedPrecision ~== mrfSummary2.weightedPrecision relTol 1e-6) + assert(mrfSummary.weightedRecall ~== mrfSummary2.weightedRecall relTol 1e-6) + + assert(mrfModel.summary.accuracy ~== mrfModel2.summary.accuracy relTol 1e-6) + assert(mrfModel.summary.weightedPrecision ~== mrfModel2.summary.weightedPrecision relTol 1e-6) + assert(mrfModel.summary.weightedRecall ~==mrfModel2.summary.weightedRecall relTol 1e-6) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index bdd37c99df0a..d70932a1bc6f 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -46,6 +46,9 @@ 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', + 'RandomForestClassificationSummary', 'RandomForestClassificationTrainingSummary', + 'BinaryRandomForestClassificationSummary', + 'BinaryRandomForestClassificationTrainingSummary', 'NaiveBayes', 'NaiveBayesModel', 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel', 'OneVsRest', 'OneVsRestModel', @@ -1762,7 +1765,7 @@ def setMinWeightFractionPerNode(self, value): class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel, _RandomForestClassifierParams, JavaMLWritable, - JavaMLReadable): + JavaMLReadable, HasTrainingSummary): """ Model fitted by RandomForestClassifier. @@ -1790,6 +1793,80 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] + @property + @since("3.1.0") + def summary(self): + """ + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. + """ + if self.hasSummary: + if self.numClasses <= 2: + return BinaryRandomForestClassificationTrainingSummary( + super(RandomForestClassificationModel, self).summary) + else: + return RandomForestClassificationTrainingSummary( + super(RandomForestClassificationModel, self).summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + + @since("3.1.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_rf_summary = self._call_java("evaluate", dataset) + if self.numClasses <= 2: + return BinaryRandomForestClassificationSummary(java_rf_summary) + else: + return RandomForestClassificationSummary(java_rf_summary) + + +class RandomForestClassificationSummary(_ClassificationSummary): + """ + Abstraction for RandomForestClassification Results for a given model. + .. versionadded:: 3.1.0 + """ + pass + + +@inherit_doc +class RandomForestClassificationTrainingSummary(RandomForestClassificationSummary, + _TrainingSummary): + """ + Abstraction for RandomForestClassificationTraining Training results. + .. versionadded:: 3.1.0 + """ + pass + + +@inherit_doc +class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary): + """ + BinaryRandomForestClassification results for a given model. + + .. versionadded:: 3.1.0 + """ + pass + + +@inherit_doc +class BinaryRandomForestClassificationTrainingSummary(BinaryRandomForestClassificationSummary, + RandomForestClassificationTrainingSummary): + """ + BinaryRandomForestClassification training results for a given model. + + .. versionadded:: 3.1.0 + """ + pass + class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity): """ diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 19acd194f4dd..7d905793188b 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -22,7 +22,9 @@ basestring = str from pyspark.ml.classification import BinaryLogisticRegressionSummary, LinearSVC, \ - LinearSVCSummary, LogisticRegression, LogisticRegressionSummary + LinearSVCSummary, BinaryRandomForestClassificationSummary, LogisticRegression, \ + LogisticRegressionSummary, RandomForestClassificationSummary, \ + RandomForestClassifier from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -235,6 +237,81 @@ def test_linear_svc_summary(self): self.assertTrue(isinstance(sameSummary, LinearSVCSummary)) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_binary_randomforest_classification_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + rf = RandomForestClassifier(weightCol="weight") + model = rf.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, BinaryRandomForestClassificationSummary)) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + + def test_multiclass_randomforest_classification_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], [])), + (2.0, 2.0, Vectors.dense(2.0)), + (2.0, 2.0, Vectors.dense(1.9))], + ["label", "weight", "features"]) + rf = RandomForestClassifier(weightCol="weight") + model = rf.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, RandomForestClassificationSummary)) + self.assertFalse(isinstance(sameSummary, BinaryRandomForestClassificationSummary)) + self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + def test_gaussian_mixture_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), (Vectors.sparse(1, [], []),)] From 59ced0ab3f68756c5084624cc340b1e5260bd791 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 27 Jun 2020 17:41:14 -0700 Subject: [PATCH 2/3] put findSummaryModel in super class --- .../spark/ml/classification/Classifier.scala | 23 +++++++++++++++++++ .../classification/LogisticRegression.scala | 21 ----------------- .../ProbabilisticClassifier.scala | 22 ++++++++++++++++++ .../RandomForestClassifier.scala | 21 ----------------- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 1f3f291644f9..233e8e5bcdc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils} import org.apache.spark.rdd.RDD @@ -269,4 +270,26 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @return predicted label */ protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax + + /** + * If the rawPrediction and prediction columns are set, this method returns the current model, + * otherwise it generates new columns for them and sets them as columns on a new copy of + * the current model + */ + private[classification] def findSummaryModel(): + (ClassificationModel[FeaturesType, M], String, String) = { + val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) { + copy(ParamMap.empty) + .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString) + .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else if ($(rawPredictionCol).isEmpty) { + copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" + + java.util.UUID.randomUUID.toString) + } else if ($(predictionCol).isEmpty) { + copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else { + this + } + (model, model.getRawPredictionCol, model.getPredictionCol) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index bbf4dd805068..47b3e2de7695 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1158,27 +1158,6 @@ class LogisticRegressionModel private[spark] ( s"(numClasses=${numClasses}), use summary instead.") } - /** - * If the probability and prediction columns are set, this method returns the current model, - * otherwise it generates new columns for them and sets them as columns on a new copy of - * the current model - */ - private[classification] def findSummaryModel(): - (LogisticRegressionModel, String, String) = { - val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) { - copy(ParamMap.empty) - .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) - .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else if ($(probabilityCol).isEmpty) { - copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) - } else if ($(predictionCol).isEmpty) { - copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else { - this - } - (model, model.getProbabilityCol, model.getPredictionCol) - } - /** * Evaluates the model on a test dataset. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 9758e3ca72c3..00b2bede4ea4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -229,6 +230,27 @@ abstract class ProbabilisticClassificationModel[ argMax } } + + /** + *If the probability and prediction columns are set, this method returns the current model, + * otherwise it generates new columns for them and sets them as columns on a new copy of + * the current model + */ + override private[classification] def findSummaryModel(): + (ProbabilisticClassificationModel[FeaturesType, M], String, String) = { + val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) { + copy(ParamMap.empty) + .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else if ($(probabilityCol).isEmpty) { + copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + } else if ($(predictionCol).isEmpty) { + copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else { + this + } + (model, model.getProbabilityCol, model.getPredictionCol) + } } private[ml] object ProbabilisticClassificationModel { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index af5401f9b599..f9ce62b91924 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -276,27 +276,6 @@ class RandomForestClassificationModel private[ml] ( s"(numClasses=${numClasses}), use summary instead.") } - /** - * If the probability and prediction columns are set, this method returns the current model, - * otherwise it generates new columns for them and sets them as columns on a new copy of - * the current model - */ - private[classification] def findSummaryModel(): - (RandomForestClassificationModel, String, String) = { - val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) { - copy(ParamMap.empty) - .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) - .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else if ($(probabilityCol).isEmpty) { - copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) - } else if ($(predictionCol).isEmpty) { - copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else { - this - } - (model, model.getProbabilityCol, model.getPredictionCol) - } - /** * Evaluates the model on a test dataset. * From 55b52bdb28a001157d0b0b265c687023267bb58c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 27 Jun 2020 18:28:33 -0700 Subject: [PATCH 3/3] rebase --- .../spark/ml/classification/LinearSVC.scala | 21 ------------------- .../ProbabilisticClassifier.scala | 2 +- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 1659bbb1d34b..4adc527c89b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -394,27 +394,6 @@ class LinearSVCModel private[classification] ( @Since("3.1.0") override def summary: LinearSVCTrainingSummary = super.summary - /** - * If the rawPrediction and prediction columns are set, this method returns the current model, - * otherwise it generates new columns for them and sets them as columns on a new copy of - * the current model - */ - private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = { - val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) { - copy(ParamMap.empty) - .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString) - .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else if ($(rawPredictionCol).isEmpty) { - copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" + - java.util.UUID.randomUUID.toString) - } else if ($(predictionCol).isEmpty) { - copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) - } else { - this - } - (model, model.getRawPredictionCol, model.getPredictionCol) - } - /** * Evaluates the model on a test dataset. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 00b2bede4ea4..1caaeccd7b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} -import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._