From 98fb73665a6cf71c2b7b9e17c9a519abec0d43b6 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 2 Sep 2015 00:07:00 -0700
Subject: [PATCH 1/5] Make a test for the current state of the world

---
 .../org/apache/spark/ml/param/ParamsSuite.scala | 15 ++++++++++-----
 .../org/apache/spark/ml/param/TestParams.scala  |  5 +++--
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 2c878f8372a47..dfab82c8b67ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite {
 
     assert(inputCol.toString === s"${uid}__inputCol")
 
+    intercept[java.util.NoSuchElementException] {
+      solver.getOrDefault(solver.handleInvalid)
+    }
+
     intercept[IllegalArgumentException] {
       solver.setMaxIter(-1)
     }
@@ -102,12 +106,13 @@
 
   test("params") {
     val solver = new TestParams()
-    import solver.{maxIter, inputCol}
+    import solver.{handleInvalid, maxIter, inputCol}
 
     val params = solver.params
-    assert(params.length === 2)
-    assert(params(0).eq(inputCol), "params must be ordered by name")
-    assert(params(1).eq(maxIter))
+    assert(params.length === 3)
+    assert(params(0).eq(handleInvalid), "params must be ordered by name")
+    assert(params(1).eq(inputCol), "params must be ordered by name")
+    assert(params(2).eq(maxIter))
 
     assert(!solver.isSet(maxIter))
     assert(solver.isDefined(maxIter))
@@ -122,7 +127,7 @@
     assert(solver.explainParam(maxIter) ===
       "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
     assert(solver.explainParams() ===
-      Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
+      Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))
 
     assert(solver.getParam("inputCol").eq(inputCol))
     assert(solver.getParam("maxIter").eq(maxIter))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 2759248344531..9d23547f28447 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.ml.param
 
-import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter}
 import org.apache.spark.ml.util.Identifiable
 
 /** A subclass of Params for testing. */
-class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
+class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter
+  with HasInputCol {
 
   def this() = this(Identifiable.randomUID("testParams"))
 

From 2f6c14d8678e2c4a81ab71a200dfa4cc4b8deea2 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 2 Sep 2015 00:10:25 -0700
Subject: [PATCH 2/5] Try keeping the same exception, but providing a string which explains the detail

---
 .../src/main/scala/org/apache/spark/ml/param/params.scala | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 91c0a5631319d..31d60b4875a9b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -501,7 +501,13 @@ trait Params extends Identifiable with Serializable {
    */
   final def getDefault[T](param: Param[T]): Option[T] = {
     shouldOwn(param)
-    defaultParamMap.get(param)
+    try {
+      defaultParamMap.get(param)
+    } catch {
+      case e: NoSuchElementException =>
+        throw new NoSuchElementException("Failed to find a default value for param" +
+          param.name)
+    }
   }
 
   /**

From 1b502b92e5f6df17e52f8ccdac72565cd8ce1fb1 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 2 Sep 2015 13:12:22 -0700
Subject: [PATCH 3/5] Use string interpolation

---
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 31d60b4875a9b..8c5935f31fc33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -505,8 +505,7 @@ trait Params extends Identifiable with Serializable {
       defaultParamMap.get(param)
     } catch {
       case e: NoSuchElementException =>
-        throw new NoSuchElementException("Failed to find a default value for param" +
-          param.name)
+        throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")
     }
   }
 

From 0235a6a64976fff85c48eb2ecc8dec4230e6cafe Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 2 Sep 2015 15:14:39 -0700
Subject: [PATCH 4/5] use getOrElse

---
 .../src/main/scala/org/apache/spark/ml/param/params.scala | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8c5935f31fc33..59ae77268debf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -501,12 +501,8 @@ trait Params extends Identifiable with Serializable {
    */
   final def getDefault[T](param: Param[T]): Option[T] = {
     shouldOwn(param)
-    try {
-      defaultParamMap.get(param)
-    } catch {
-      case e: NoSuchElementException =>
-        throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")
-    }
+    defaultParamMap.getOrElse(param,
+      throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
   }
 
   /**

From bf324c49af29ed1012dee9b2b4c8dd723c59bb6c Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 2 Sep 2015 15:33:28 -0700
Subject: [PATCH 5/5] Move logic to correct function (oops)

---
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 59ae77268debf..de32b7218c277 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable {
    */
   final def getOrDefault[T](param: Param[T]): T = {
     shouldOwn(param)
-    get(param).orElse(getDefault(param)).get
+    get(param).orElse(getDefault(param)).getOrElse(
+      throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
   }
 
   /** An alias for [[getOrDefault()]]. */
@@ -501,8 +502,7 @@ trait Params extends Identifiable with Serializable {
    */
   final def getDefault[T](param: Param[T]): Option[T] = {
     shouldOwn(param)
-    defaultParamMap.getOrElse(param,
-      throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
+    defaultParamMap.get(param)
  }
 
   /**
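
Note (not part of the patches): a rough sketch of the behavior this series converges on, assuming the TestParams helper from patch 1 (whose handleInvalid param deliberately has no default value) and the state of params.scala after patch 5. getOrDefault now fails with a message that names the offending param, while getDefault keeps its Option-returning contract.

// Sketch only -- illustrates the intended end-state behavior, not shipped code.
import org.apache.spark.ml.param.TestParams

val solver = new TestParams()

// maxIter has a default (10), so getOrDefault resolves it as before.
assert(solver.getOrDefault(solver.maxIter) == 10)

// handleInvalid is neither set nor defaulted: instead of a bare "None.get"
// failure from Option.get, the lookup raises a NoSuchElementException whose
// message names the param.
try {
  solver.getOrDefault(solver.handleInvalid)
  assert(false, "expected a NoSuchElementException")
} catch {
  case e: NoSuchElementException =>
    assert(e.getMessage.contains("handleInvalid"))
}

// getDefault stays Option-valued and simply returns None here.
assert(solver.getDefault(solver.handleInvalid).isEmpty)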