Skip to content

Commit 29b004c

Browse files
committed
update ParamsSuite
1 parent 94fd98e commit 29b004c

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,20 @@ import org.scalatest.FunSuite
2121

2222
class ParamsSuite extends FunSuite {
2323

24-
val solver = new TestParams()
25-
import solver.{inputCol, maxIter}
26-
2724
test("param") {
25+
val solver = new TestParams()
26+
import solver.maxIter
27+
2828
assert(maxIter.name === "maxIter")
2929
assert(maxIter.doc === "max number of iterations")
3030
assert(maxIter.parent.eq(solver))
3131
assert(maxIter.toString === "maxIter: max number of iterations")
32-
assert(solver.getMaxIter === 10)
33-
assert(!solver.isSet(inputCol))
3432
}
3533

3634
test("param pair") {
35+
val solver = new TestParams()
36+
import solver.maxIter
37+
3738
val pair0 = maxIter -> 5
3839
val pair1 = maxIter.w(5)
3940
val pair2 = ParamPair(maxIter, 5)
@@ -44,6 +45,9 @@ class ParamsSuite extends FunSuite {
4445
}
4546

4647
test("param map") {
48+
val solver = new TestParams()
49+
import solver.{maxIter, inputCol}
50+
4751
val map0 = ParamMap.empty
4852

4953
assert(!map0.contains(maxIter))
@@ -77,23 +81,42 @@ class ParamsSuite extends FunSuite {
7781
}
7882

7983
test("params") {
84+
val solver = new TestParams()
85+
import solver.{maxIter, inputCol}
86+
8087
val params = solver.params
8188
assert(params.length === 2)
8289
assert(params(0).eq(inputCol), "params must be ordered by name")
8390
assert(params(1).eq(maxIter))
84-
assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
91+
92+
assert(!solver.isSet(maxIter))
93+
assert(solver.isDefined(maxIter))
94+
assert(solver.getMaxIter === 10)
95+
solver.setMaxIter(100)
96+
assert(solver.isSet(maxIter))
97+
assert(solver.getMaxIter === 100)
98+
assert(!solver.isSet(inputCol))
99+
assert(!solver.isDefined(inputCol))
100+
intercept[NoSuchElementException](solver.getInputCol)
101+
102+
assert(
103+
solver.explain(maxIter) === "maxIter: max number of iterations (default: 10, current: 100)")
104+
assert(solver.explain(inputCol) === "inputCol: input column name (undefined)")
105+
assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explain).mkString("\n"))
106+
85107
assert(solver.getParam("inputCol").eq(inputCol))
86108
assert(solver.getParam("maxIter").eq(maxIter))
87109
intercept[NoSuchElementException] {
88110
solver.getParam("abc")
89111
}
90-
assert(!solver.isSet(inputCol))
112+
91113
intercept[IllegalArgumentException] {
92114
solver.validate()
93115
}
94116
solver.validate(ParamMap(inputCol -> "input"))
95117
solver.setInputCol("input")
96118
assert(solver.isSet(inputCol))
119+
assert(solver.isDefined(inputCol))
97120
assert(solver.getInputCol === "input")
98121
solver.validate()
99122
intercept[IllegalArgumentException] {

0 commit comments

Comments
 (0)