@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
- testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
@sethah (Contributor) commented on Mar 3, 2017:
It would seriously reduce the amount of code changed (and therefore make this much easier to review :p ) to just add an extra constructor:

def testEstimatorAndModelReadWrite[
    E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
      estimator: E,
      dataset: Dataset[_],
      testParams: Map[String, Any],
      checkModelData: (M, M) => Unit): Unit = {
    testEstimatorAndModelReadWrite(estimator, dataset, testParams, testParams, checkModelData)
}

Contributor Author replied:
@sethah Thanks for the kind reminder. I was planning to write it the way you suggest, but then I thought:

  • Actually, each Model only needs to extend a fraction of its Estimator's Params, so every testEstimatorAndModelReadWrite(estimator, dataset, testParams, checkModelData) call should eventually be changed to testEstimatorAndModelReadWrite(estimator, dataset, testEstimatorParams, testModelParams, checkModelData). I explicitly wrote the test suites the latter way to remind developers that they should separate their estimator and model params when adding read/write tests for new algorithms. I'm afraid developers won't be aware of the separation if they refer to other test suites and find that almost all test cases pass in only testParams.
  • Though in the current change we pass allParamSettings as both testEstimatorParams and testModelParams, that is only because they happen to share the same param set. Others, like ALS, will pass in separate params (see the sketch below). I think we should push forward with refactoring *** and ***Model to separate their params, which would make the models more succinct.
  • If this were a public API, I would totally agree with you. However, this is an internal auxiliary function, and since all test cases will need to pass in separate params eventually, I'd rather settle the matter in one go.
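For example, a minimal sketch of the separated call, mirroring the ALS change in this patch (allEstimatorParamSettings and allModelParamSettings are defined in ALSSuite; ratings and checkModelData come from the test body):

// ALS's estimator params and model params differ, so the two maps are passed
// separately; the model param map is a subset of the estimator param map.
val als = new ALS()
testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
  allModelParamSettings, checkModelData)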

That's my two cents, but I'm still open to hearing your thoughts. If you feel strongly about it, I can update the patch according to your suggestion. Thanks.

Contributor replied:

Good points. I still think it's better to just add the extra constructor, but I don't feel strongly about it. So we can proceed with whatever you feel is best. Thanks!

Contributor commented:

I prefer to let any refactoring of these tests happen as-needed. If there are specific cases that need to be done now, we should create JIRAs to track them.

Contributor Author replied:

Yeah, actually almost all models extend lots of params that they don't need; I'd like to remove those params from the models as a todo item. I'll create JIRAs to track them. I'll merge this first since it blocks #17117. Thanks.

+ allParamSettings, checkModelData)

// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)

// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
- checkModelData)
+ allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}

@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext

val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}

@@ -202,7 +202,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
val svm = new LinearSVC()
testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
- checkModelData)
+ LinearSVCSuite.allParamSettings, checkModelData)
}
}

@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
}
val lr = new LogisticRegression()
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
- checkModelData)
+ LogisticRegressionSuite.allParamSettings, checkModelData)
}

test("should support all NumericType labels and weights, and not support other types") {
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
assert(model.theta === model2.theta)
}
val nb = new NaiveBayes()
- testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+ NaiveBayesSuite.allParamSettings, checkModelData)
}

test("should support all NumericType labels and weights, and not support other types") {
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite

val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}

@@ -138,8 +138,8 @@ class BisectingKMeansSuite
assert(model.clusterCenters === model2.clusterCenters)
}
val bisectingKMeans = new BisectingKMeans()
- testEstimatorAndModelReadWrite(
- bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+ BisectingKMeansSuite.allParamSettings, checkModelData)
}
}

@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
}
val gm = new GaussianMixture()
- testEstimatorAndModelReadWrite(gm, dataset,
+ testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
GaussianMixtureSuite.allParamSettings, checkModelData)
}

@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(model.clusterCenters === model2.clusterCenters)
}
val kmeans = new KMeans()
- testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+ KMeansSuite.allParamSettings, checkModelData)
}
}

@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
- testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+ LDASuite.allParamSettings, checkModelData)
}

test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset,
+ LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}

@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
}
val mh = new BucketedRandomProjectionLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}

test("hashFunction") {
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.selectedFeatures === model2.selectedFeatures)
}
val nb = new ChiSqSelector
- testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+ ChiSqSelectorSuite.allParamSettings, checkModelData)
}

test("should support all NumericType labels and not support other types") {
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
val mh = new MinHashLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}

test("hashFunction") {
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
model2.freqItemsets.sort("items").collect())
}
val fPGrowth = new FPGrowth()
- testEstimatorAndModelReadWrite(
- fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+ FPGrowthSuite.allParamSettings, checkModelData)
}

}
@@ -452,37 +452,26 @@ class ALSSuite
}

test("read/write") {
- import ALSSuite._
- val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- val als = new ALS()
- allEstimatorParamSettings.foreach { case (p, v) =>
- als.set(als.getParam(p), v)
- }
- val spark = this.spark
- import spark.implicits._
- val model = als.fit(ratings.toDF())

- // Test Estimator save/load
- val als2 = testDefaultReadWrite(als)
- allEstimatorParamSettings.foreach { case (p, v) =>
- val param = als.getParam(p)
- assert(als.get(param).get === als2.get(param).get)
- }
+ import ALSSuite._
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)

- // Test Model save/load
- val model2 = testDefaultReadWrite(model)
- allModelParamSettings.foreach { case (p, v) =>
- val param = model.getParam(p)
- assert(model.get(param).get === model2.get(param).get)
- }
- assert(model.rank === model2.rank)
def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
df.select("id", "features").collect().map { case r =>
(r.getInt(0), r.getAs[Array[Float]](1))
}.toSet
}
- assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
- assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))

+ def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+ assert(model.rank === model2.rank)
+ assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+ assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+ }

+ val als = new ALS()
+ testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+ allModelParamSettings, checkModelData)
}

test("input type validation") {
@@ -387,7 +387,8 @@ class AFTSurvivalRegressionSuite
}
val aft = new AFTSurvivalRegression()
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
- AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+ AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+ checkModelData)
}

test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
testEstimatorAndModelReadWrite(dt, categoricalData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
testEstimatorAndModelReadWrite(dt, continuousData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}

@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite

val glr = new GeneralizedLinearRegression()
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+ GeneralizedLinearRegressionSuite.allParamSettings,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}

@@ -178,7 +178,7 @@ class IsotonicRegressionSuite

val ir = new IsotonicRegression()
testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
- checkModelData)
+ IsotonicRegressionSuite.allParamSettings, checkModelData)
}

test("should support all NumericType labels and weights, and not support other types") {
@@ -985,7 +985,7 @@ class LinearRegressionSuite
}
val lr = new LinearRegression()
testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
- checkModelData)
+ LinearRegressionSuite.allParamSettings, checkModelData)
}

test("should support all NumericType labels and weights, and not support other types") {
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex

val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}

@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* - Check Params on Estimator and Model
* - Compare model data
*
- * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
*
* @param estimator Estimator to test
* @param dataset Dataset to pass to [[Estimator.fit()]]
- * @param testParams Set of [[Param]] values to set in estimator
+ * @param testEstimatorParams Set of [[Param]] values to set in estimator
+ * @param testModelParams Set of [[Param]] values to set in model
* @param checkModelData Method which takes the original and loaded [[Model]] and compares their
* data. This method does not need to check [[Param]] values.
* @tparam E Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
dataset: Dataset[_],
- testParams: Map[String, Any],
+ testEstimatorParams: Map[String, Any],
+ testModelParams: Map[String, Any],
checkModelData: (M, M) => Unit): Unit = {
// Set some Params to make sure set Params are serialized.
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
estimator.set(estimator.getParam(p), v)
}
val model = estimator.fit(dataset)

// Test Estimator save/load
val estimator2 = testDefaultReadWrite(estimator)
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
val param = estimator.getParam(p)
assert(estimator.get(param).get === estimator2.get(param).get)
}

// Test Model save/load
val model2 = testDefaultReadWrite(model)
- testParams.foreach { case (p, v) =>
+ testModelParams.foreach { case (p, v) =>
val param = model.getParam(p)
assert(model.get(param).get === model2.get(param).get)
}
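A minimal sketch of how a suite would call the updated helper when the estimator and model param sets genuinely differ (the names below are hypothetical, for illustration only; they are not part of this patch):

// Hypothetical suite code: myEstimator trains myModel, which retains only "seed",
// so the model param map is a strict subset of the estimator param map.
val estimatorParams = Map("maxIter" -> 5, "seed" -> 42L)
val modelParams = Map("seed" -> 42L)
testEstimatorAndModelReadWrite(myEstimator, myDataset, estimatorParams,
  modelParams, checkMyModelData)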