Skip to content

Commit 2ecdc73

Browse files
committed
use ParamValidators
1 parent f593f5b commit 2ecdc73

File tree

3 files changed

+27
-32
lines changed

3 files changed

+27
-32
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
140140
* by `inputCol`. An exception will be thrown if both are set.
141141
*/
142142
private[feature] def isBucketizeMultipleColumns(): Boolean = {
143-
inputColsSanityCheck()
144-
outputColsSanityCheck()
143+
ParamValidators.assertColOrCols(this)
145144
if (isSet(inputCol) && isSet(splitsArray)) {
146-
raiseIncompatibleParamsException("inputCol", "splitsArray")
145+
ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
147146
}
148147
if (isSet(inputCols) && isSet(splits)) {
149-
raiseIncompatibleParamsException("inputCols", "splits")
148+
ParamValidators.raiseIncompatibleParamsException("inputCols", "splits")
150149
}
151150
isSet(inputCols)
152151
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._
3131
import org.apache.spark.SparkException
3232
import org.apache.spark.annotation.{DeveloperApi, Since}
3333
import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
34+
import org.apache.spark.ml.param.shared._
3435
import org.apache.spark.ml.util.Identifiable
3536

3637
/**
@@ -249,6 +250,29 @@ object ParamValidators {
249250
def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
250251
value.length > lowerBound
251252
}
253+
254+
/**
255+
* Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If
256+
* this is not true, an `IllegalArgumentException` is raised.
257+
* @param model
258+
*/
259+
def assertColOrCols(model: Params): Unit = {
260+
model match {
261+
case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
262+
raiseIncompatibleParamsException("inputCols", "inputCol")
263+
case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) =>
264+
raiseIncompatibleParamsException("outputCols", "inputCol")
265+
case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) =>
266+
raiseIncompatibleParamsException("inputCols", "outputCol")
267+
case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
268+
raiseIncompatibleParamsException("outputCols", "outputCol")
269+
case _ =>
270+
}
271+
}
272+
273+
def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
274+
throw new IllegalArgumentException(s"Both `$paramName1` and `$paramName2` are set.")
275+
}
252276
}
253277

254278
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
@@ -834,14 +858,6 @@ trait Params extends Identifiable with Serializable {
834858
}
835859
to
836860
}
837-
838-
protected def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
839-
throw new IllegalArgumentException(
840-
s"""
841-
|Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports
842-
|setting either `$paramName1` or `$paramName2`.
843-
""".stripMargin)
844-
}
845861
}
846862

847863
/**

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,6 @@ trait HasInputCols extends Params {
236236

237237
/** @group getParam */
238238
final def getInputCols: Array[String] = $(inputCols)
239-
240-
final def inputColsSanityCheck(): Unit = {
241-
this match {
242-
case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) =>
243-
raiseIncompatibleParamsException("inputCols", "inputCol")
244-
case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) =>
245-
raiseIncompatibleParamsException("inputCols", "outputCol")
246-
case _ =>
247-
}
248-
}
249239
}
250240

251241
/**
@@ -282,16 +272,6 @@ trait HasOutputCols extends Params {
282272

283273
/** @group getParam */
284274
final def getOutputCols: Array[String] = $(outputCols)
285-
286-
final def outputColsSanityCheck(): Unit = {
287-
this match {
288-
case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) =>
289-
raiseIncompatibleParamsException("outputCols", "inputCol")
290-
case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) =>
291-
raiseIncompatibleParamsException("outputCols", "outputCol")
292-
case _ =>
293-
}
294-
}
295275
}
296276

297277
/**

0 commit comments

Comments
 (0)