@@ -20,8 +20,8 @@ package org.apache.spark.ml.param
2020import java .lang .reflect .Modifier
2121import java .util .NoSuchElementException
2222
23- import scala .collection .mutable
2423import scala .annotation .varargs
24+ import scala .collection .mutable
2525
2626import org .apache .spark .annotation .{AlphaComponent , DeveloperApi }
2727import org .apache .spark .ml .util .Identifiable
@@ -90,45 +90,50 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
9090 }
9191}
9292
93- /** Factory methods for common validation functions for [[Param.isValid ]] */
93+ /**
94+ * Factory methods for common validation functions for [[Param.isValid ]].
95+ * The numerical methods only support Int, Long, Float, and Double.
96+ */
9497object ParamValidate {
9598
9699 /** Default validation always return true */
97100 def default [T ]: T => Boolean = (_ : T ) => true
98101
99- /** Negate the given check */
100- def not [T ](isValid : T => Boolean ): T => Boolean = { (value : T ) =>
101- ! isValid(value)
102+ /**
103+ * Private method for checking numerical types and converting to Double.
104+ * This is mainly for the sake of compilation; type checks are really handled
105+ * by [[Params ]] setters and the [[ParamPair ]] constructor.
106+ */
107+ private def getDouble [T ](value : T ): Double = value match {
108+ case x : Int => x.toDouble
109+ case x : Long => x.toDouble
110+ case x : Float => x.toDouble
111+ case x : Double => x.toDouble
112+ case _ =>
113+ // The type should be checked before this is ever called.
114+ throw new IllegalArgumentException (" Numerical Param validation failed because" +
115+ s " of unexpected input type: ${value.getClass}" )
102116 }
103117
104- /** Combine two checks */
105- def and [T ](isValid1 : T => Boolean , isValid2 : T => Boolean ): T => Boolean = { (value : T ) =>
106- isValid1 (value) && isValid2(value)
118+ /** Check if value > lowerBound */
119+ def gt [T ](lowerBound : Double ): T => Boolean = { (value : T ) =>
120+ getDouble (value) > lowerBound
107121 }
108122
109- /** Check for value > lowerBound. Use [[not() ]] for <= check. */
110- def gt (lowerBound : Int ): Int => Boolean = { (value : Int ) => value > lowerBound }
111-
112- /** Check for value >= lowerBound. Use [[not() ]] for < check. */
113- def gtEq (lowerBound : Int ): Int => Boolean = { (value : Int ) => value >= lowerBound }
114-
115- /** Check for value > lowerBound. Use [[not() ]] for <= check. */
116- def gt (lowerBound : Long ): Long => Boolean = { (value : Long ) => value > lowerBound }
117-
118- /** Check for value >= lowerBound. Use [[not() ]] for < check. */
119- def gtEq (lowerBound : Long ): Long => Boolean = { (value : Long ) => value >= lowerBound }
120-
121- /** Check for value > lowerBound. Use [[not() ]] for <= check. */
122- def gt (lowerBound : Float ): Float => Boolean = { (value : Float ) => value > lowerBound }
123-
124- /** Check for value >= lowerBound. Use [[not() ]] for < check. */
125- def gtEq (lowerBound : Float ): Float => Boolean = { (value : Float ) => value >= lowerBound }
123+ /** Check if value >= lowerBound */
124+ def gtEq [T ](lowerBound : Double ): T => Boolean = { (value : T ) =>
125+ getDouble(value) >= lowerBound
126+ }
126127
127- /** Check for value > lowerBound. Use [[not() ]] for <= check. */
128- def gt (lowerBound : Double ): Double => Boolean = { (value : Double ) => value > lowerBound }
128+ /** Check if value < upperBound */
129+ def lt [T ](upperBound : Double ): T => Boolean = { (value : T ) =>
130+ getDouble(value) < upperBound
131+ }
129132
130- /** Check for value >= lowerBound. Use [[not() ]] for < check. */
131- def gtEq (lowerBound : Double ): Double => Boolean = { (value : Double ) => value >= lowerBound }
133+ /** Check if value <= upperBound */
134+ def ltEq [T ](upperBound : Double ): T => Boolean = { (value : T ) =>
135+ getDouble(value) <= upperBound
136+ }
132137
133138 /**
134139 * Check for value in range lowerBound to upperBound.
@@ -137,33 +142,31 @@ object ParamValidate {
137142 * @param upperInclusive If true, check for value <= upperBound.
138143 * If false, check for value < upperBound.
139144 */
140- def inRange [T <: Comparable [ T ] ](
141- lowerBound : T ,
142- upperBound : T ,
145+ def inRange [T ](
146+ lowerBound : Double ,
147+ upperBound : Double ,
143148 lowerInclusive : Boolean ,
144- upperInclusive : Boolean ): T => Boolean = { (x : T ) =>
145- val lowerValid = if (lowerInclusive) {
146- x.compareTo(lowerBound) >= 0
147- } else {
148- x.compareTo(lowerBound) > 0
149- }
150- val upperValid = if (upperInclusive) {
151- x.compareTo(upperBound) <= 0
152- } else {
153- x.compareTo(upperBound) < 0
154- }
149+ upperInclusive : Boolean ): T => Boolean = { (value : T ) =>
150+ val x : Double = getDouble(value)
151+ val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
152+ val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
155153 lowerValid && upperValid
156154 }
157155
158156 /** Version of [[inRange() ]] which uses inclusive be default: [lowerBound, upperBound] */
159- def inRange [T ](lowerBound : T , upperBound : T ): T => Boolean = {
157+ def inRange [T ](lowerBound : Double , upperBound : Double ): T => Boolean = {
160158 inRange[T ](lowerBound, upperBound, lowerInclusive = true , upperInclusive = true )
161159 }
162160
163161 /** Check for value in an allowed set of values. */
164162 def inArray [T ](allowed : Array [T ]): T => Boolean = { (value : T ) =>
165163 allowed.contains(value)
166164 }
165+
166+ /** Check for value in an allowed set of values. */
167+ def inArray [T ](allowed : java.util.List [T ]): T => Boolean = { (value : T ) =>
168+ allowed.contains(value)
169+ }
167170}
168171
169172// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
0 commit comments