Skip to content

Commit efff28b

Browse files
committed
Separate estimator and model params for read/write test.
1 parent d2a8797 commit efff28b

22 files changed

+47
-31
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
372372
// Categorical splits with tree depth 2
373373
val categoricalData: DataFrame =
374374
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
375-
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
375+
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
376+
allParamSettings, checkModelData)
376377

377378
// Continuous splits with tree depth 2
378379
val continuousData: DataFrame =
379380
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
380-
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
381+
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
382+
allParamSettings, checkModelData)
381383

382384
// Continuous splits with tree depth 0
383385
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
384-
checkModelData)
386+
allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
385387
}
386388
}
387389

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
374374

375375
val continuousData: DataFrame =
376376
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
377-
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
377+
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
378+
allParamSettings, checkModelData)
378379
}
379380
}
380381

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
202202
}
203203
val svm = new LinearSVC()
204204
testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
205-
checkModelData)
205+
LinearSVCSuite.allParamSettings, checkModelData)
206206
}
207207
}
208208

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
20892089
}
20902090
val lr = new LogisticRegression()
20912091
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
2092-
checkModelData)
2092+
LogisticRegressionSuite.allParamSettings, checkModelData)
20932093
}
20942094

20952095
test("should support all NumericType labels and weights, and not support other types") {

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
280280
assert(model.theta === model2.theta)
281281
}
282282
val nb = new NaiveBayes()
283-
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
283+
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
284+
NaiveBayesSuite.allParamSettings, checkModelData)
284285
}
285286

286287
test("should support all NumericType labels and weights, and not support other types") {

mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
218218

219219
val continuousData: DataFrame =
220220
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
221-
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
221+
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
222+
allParamSettings, checkModelData)
222223
}
223224
}
224225

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
138138
assert(model.clusterCenters === model2.clusterCenters)
139139
}
140140
val bisectingKMeans = new BisectingKMeans()
141-
testEstimatorAndModelReadWrite(
142-
bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
141+
testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
142+
BisectingKMeansSuite.allParamSettings, checkModelData)
143143
}
144144
}
145145

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
163163
assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
164164
}
165165
val gm = new GaussianMixture()
166-
testEstimatorAndModelReadWrite(gm, dataset,
166+
testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
167167
GaussianMixtureSuite.allParamSettings, checkModelData)
168168
}
169169

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
150150
assert(model.clusterCenters === model2.clusterCenters)
151151
}
152152
val kmeans = new KMeans()
153-
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
153+
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
154+
KMeansSuite.allParamSettings, checkModelData)
154155
}
155156
}
156157

mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
250250
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
251251
}
252252
val lda = new LDA()
253-
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
253+
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
254+
LDASuite.allParamSettings, checkModelData)
254255
}
255256

256257
test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
271272
}
272273
val lda = new LDA()
273274
testEstimatorAndModelReadWrite(lda, dataset,
275+
LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
274276
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
275277
}
276278

0 commit comments

Comments
 (0)