Skip to content

Commit 0d9594e

Browse files
committed
add getOrElse to ParamMap
1 parent eeeffe8 commit 0d9594e

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ trait Params extends Identifiable with Serializable {
104104

105105
/**
106106
* Returns all params. The default implementation uses Java reflection to list all public methods
107-
* that have return type [[Param]].
107+
* that return [[Param]] and have no arguments.
108108
*/
109109
def params: Array[Param[_]] = {
110110
val methods = this.getClass.getMethods
@@ -264,12 +264,13 @@ private[spark] object Params {
264264
* A param to value map.
265265
*/
266266
@AlphaComponent
267-
class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable {
267+
final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
268+
extends Serializable {
268269

269270
/**
270271
* Creates an empty param map.
271272
*/
272-
def this() = this(mutable.Map.empty[Param[Any], Any])
273+
def this() = this(mutable.Map.empty)
273274

274275
/**
275276
* Puts a (param, value) pair (overwrites if the input param exists).
@@ -291,21 +292,25 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
291292
}
292293

293294
/**
294-
* Optionally returns the value associated with a param or its default.
295+
* Optionally returns the value associated with a param.
295296
*/
296297
def get[T](param: Param[T]): Option[T] = {
297298
map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
298299
}
299300

301+
/**
302+
* Returns the value associated with a param or a default value.
303+
*/
304+
def getOrElse[T](param: Param[T], default: T): T = {
305+
get(param).getOrElse(default)
306+
}
307+
300308
/**
301309
* Gets the value of the input param or its default value if it does not exist.
302310
* Raises a NoSuchElementException if there is no value associated with the input param.
303311
*/
304312
def apply[T](param: Param[T]): T = {
305-
val value = get(param)
306-
if (value.isDefined) {
307-
value.get
308-
} else {
313+
get(param).getOrElse {
309314
throw new NoSuchElementException(s"Cannot find param ${param.name}.")
310315
}
311316
}
@@ -326,7 +331,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
326331
}
327332

328333
/**
329-
* Make a copy of this param map.
334+
* Creates a copy of this param map.
330335
*/
331336
def copy: ParamMap = new ParamMap(map.clone())
332337

@@ -364,7 +369,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
364369
}
365370

366371
/**
367-
* Number of param pairs in this set.
372+
* Number of param pairs in this map.
368373
*/
369374
def size: Int = map.size
370375
}

0 commit comments

Comments
 (0)