Skip to content

Commit 44948a2

Browse files
holdenkmengxr
authored andcommitted
[SPARK-9723] [ML] params getordefault should throw more useful error
Params.getOrDefault should throw a more meaningful exception than what you get from a bad key lookup. Author: Holden Karau <[email protected]> Closes #8567 from holdenk/SPARK-9723-params-getordefault-should-throw-more-useful-error.
1 parent 03f3e91 commit 44948a2

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable {
461461
*/
462462
final def getOrDefault[T](param: Param[T]): T = {
463463
shouldOwn(param)
464-
get(param).orElse(getDefault(param)).get
464+
get(param).orElse(getDefault(param)).getOrElse(
465+
throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
465466
}
466467

467468
/** An alias for [[getOrDefault()]]. */

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite {
4040

4141
assert(inputCol.toString === s"${uid}__inputCol")
4242

43+
intercept[java.util.NoSuchElementException] {
44+
solver.getOrDefault(solver.handleInvalid)
45+
}
46+
4347
intercept[IllegalArgumentException] {
4448
solver.setMaxIter(-1)
4549
}
@@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite {
102106

103107
test("params") {
104108
val solver = new TestParams()
105-
import solver.{maxIter, inputCol}
109+
import solver.{handleInvalid, maxIter, inputCol}
106110

107111
val params = solver.params
108-
assert(params.length === 2)
109-
assert(params(0).eq(inputCol), "params must be ordered by name")
110-
assert(params(1).eq(maxIter))
112+
assert(params.length === 3)
113+
assert(params(0).eq(handleInvalid), "params must be ordered by name")
114+
assert(params(1).eq(inputCol), "params must be ordered by name")
115+
assert(params(2).eq(maxIter))
111116

112117
assert(!solver.isSet(maxIter))
113118
assert(solver.isDefined(maxIter))
@@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite {
122127
assert(solver.explainParam(maxIter) ===
123128
"maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
124129
assert(solver.explainParams() ===
125-
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
130+
Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))
126131

127132
assert(solver.getParam("inputCol").eq(inputCol))
128133
assert(solver.getParam("maxIter").eq(maxIter))

mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.spark.ml.param
1919

20-
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
20+
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter}
2121
import org.apache.spark.ml.util.Identifiable
2222

2323
/** A subclass of Params for testing. */
24-
class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
24+
class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter
25+
with HasInputCol {
2526

2627
def this() = this(Identifiable.randomUID("testParams"))
2728

0 commit comments

Comments
 (0)