File: mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+  DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeClassifierSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import DecisionTreeClassifierSuite.compareAPIs
   import testImplicits._
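Background on the central change in these suites: extending MLTest instead of SparkFunSuite with MLlibTestSparkContext makes testTransformer and testTransformerByGlobalCheckFunc available, helpers that exercise a transformer against both batch and streaming input with the same checks. Below is a minimal sketch of that pattern, assuming a SparkSession is available; the helper name and the memory-sink plumbing are this sketch's own, not MLTest's actual implementation.

import org.apache.spark.ml.Transformer
import org.apache.spark.sql.{Encoder, Row, SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

// Illustrative stand-in for MLTest.testTransformer: run `transformer` over the
// same rows once as a batch DataFrame and once as a memory-backed stream,
// applying one and the same per-row check to both outputs.
def testTransformerSketch[A <: Product : Encoder](
    spark: SparkSession,
    data: Seq[A],
    transformer: Transformer,
    outputCols: Seq[String])(check: Row => Unit): Unit = {
  import spark.implicits._

  // Batch path: an ordinary DataFrame transform.
  transformer.transform(data.toDF())
    .select(outputCols.head, outputCols.tail: _*)
    .collect()
    .foreach(check)

  // Streaming path: feed the rows through a MemoryStream and read the
  // transformed micro-batches back from a named in-memory sink.
  implicit val sqlContext: SQLContext = spark.sqlContext
  val input = MemoryStream[A]
  val query = transformer.transform(input.toDF())
    .select(outputCols.head, outputCols.tail: _*)
    .writeStream
    .format("memory")
    .queryName("transformerOutput")
    .start()
  input.addData(data: _*)
  query.processAllAvailable()
  spark.table("transformerOutput").collect().foreach(check)
  query.stop()
}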
@@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite
 
     MLTestingUtils.checkCopyAndUids(dt, newTree)
 
-    val predictions = newTree.transform(newData)
-      .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
-      .collect()
-
-    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
-      assert(pred === rawPred.argmax,
-        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
-      val sum = rawPred.toArray.sum
-      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
-        "probability prediction mismatch")
+    testTransformer[(Vector, Double)](newData, newTree,
+      "prediction", "rawPrediction", "probability") {
+      case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+        assert(pred === rawPred.argmax,
+          s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+        val sum = rawPred.toArray.sum
+        assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+          "probability prediction mismatch")
     }
 
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, DecisionTreeClassificationModel](newTree, newData)
+      Vector, DecisionTreeClassificationModel](this, newTree, newData)
   }
 
   test("training with 1-category categorical feature") {
File: mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -26,22 +26,20 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.LogLoss
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
  * Test suite for [[GBTClassifier]].
  */
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
   import GBTClassifierSuite.compareAPIs
@@ -126,30 +124,34 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // should predict all zeros
     binaryModel.setThresholds(Array(0.0, 1.0))
-    val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => assert(prediction === 0.0)
+    }
 
     // should predict all ones
     binaryModel.setThresholds(Array(1.0, 0.0))
-    val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => assert(prediction === 1.0)
+    }
 
     val gbtBase = new GBTClassifier
     val model = gbtBase.fit(df)
     val basePredictions = model.transform(df).select("prediction").collect()
 
     // constant threshold scaling is the same as no thresholds
     binaryModel.setThresholds(Array(1.0, 1.0))
-    val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
-      scaled.getDouble(0) === base.getDouble(0)
-    })
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") {
+      scaledPredictions: Seq[Row] =>
+        assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+          scaled.getDouble(0) === base.getDouble(0)
+        })
+    }
 
     // force it to use the predict method
     model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
-    val predictionsWithPredict = model.transform(df).select("prediction").collect()
-    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, model, "prediction") {
+      case Row(prediction: Double) => assert(prediction === 0.0)
+    }
   }
 
   test("GBTClassifier: Predictor, Classifier methods") {
@@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     val blas = BLAS.getInstance()
 
     val validationDataset = validationData.toDF(labelCol, featuresCol)
-    val results = gbtModel.transform(validationDataset)
-    // check that raw prediction is tree predictions dot tree weights
-    results.select(rawPredictionCol, featuresCol).collect().foreach {
-      case Row(raw: Vector, features: Vector) =>
+    testTransformer[(Double, Vector)](validationDataset, gbtModel,
+      "rawPrediction", "features", "probability", "prediction") {
+      case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) =>
         assert(raw.size === 2)
+        // check that raw prediction is tree predictions dot tree weights
         val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
         val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
         assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
-    }
-
-    // Compare rawPrediction with probability
-    results.select(rawPredictionCol, probabilityCol).collect().foreach {
-      case Row(raw: Vector, prob: Vector) =>
-        assert(raw.size === 2)
+        // Compare rawPrediction with probability
         assert(prob.size === 2)
         // Note: we should check other loss types for classification if they are added
         val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
         assert(prob(0) ~== predFromRaw(0) relTol eps)
         assert(prob(1) ~== predFromRaw(1) relTol eps)
         assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
-    }
-
-    // Compare prediction with probability
-    results.select(predictionCol, probabilityCol).collect().foreach {
-      case Row(pred: Double, prob: Vector) =>
+        // Compare prediction with probability
         val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
         assert(pred == predFromProb)
     }
 
-    // force it to use raw2prediction
-    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
-    val resultsUsingRaw2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use probability2prediction
-    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
-    val resultsUsingProb2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use predict
Review comment (Contributor):
Why were these transformations and checks removed?

Reply (Contributor, Author):
These testing code paths have been covered by ProbabilisticClassifierSuite.testPredictMethods.
(A sketch of the idea behind that helper follows this hunk.)

-    gbtModel.setRawPredictionCol("").setProbabilityCol("")
-    val resultsUsingPredict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, GBTClassificationModel](gbtModel, validationDataset)
+      Vector, GBTClassificationModel](this, gbtModel, validationDataset)
   }
 
   test("GBT parameter stepSize should be in interval (0, 1]") {
File: mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -21,20 +21,18 @@ import scala.util.Random
 
 import breeze.linalg.{DenseVector => BDV}
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LinearSVCSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.optim.aggregator.HingeAggregator
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.udf
 
 
-class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
       threshold: Double,
       expected: Set[(Int, Double)]): Unit = {
     model.setThreshold(threshold)
-    val results = model.transform(df).select("id", "prediction").collect()
-      .map(r => (r.getInt(0), r.getDouble(1)))
-      .toSet
-    assert(results === expected, s"Failed for threshold = $threshold")
+    testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") {
+      rows: Seq[Row] =>
+        val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet
+        assert(results === expected, s"Failed for threshold = $threshold")
+    }
   }
 
   def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {
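The global variant is required in this hunk because the expectation is a Set over all (id, prediction) pairs, a property of the complete output rather than of any single row, so a per-row testTransformer check could not express it. A standalone illustration of the same pattern, with made-up data and hypothetical df and model values:

// Made-up expectation: set equality over all (id, prediction) pairs needs the
// whole result at once, which the global check provides as Seq[Row].
val expected = Set((0, 1.0), (1, 0.0), (2, 1.0))
testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") {
  rows: Seq[Row] =>
    assert(rows.map(r => (r.getInt(0), r.getDouble(1))).toSet === expected)
}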