Skip to content

Commit bb0c0d2

Browse files
committed
address comments
1 parent 8f3581c commit bb0c0d2

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
3434
/**
3535
* `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
3636
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
37-
* when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
38-
* only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
39-
* only used for single column usage, and `splitsArray` is for multiple columns.
37+
* when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
38+
* `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
39+
* columns.
4040
*/
4141
@Since("1.4.0")
4242
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -140,15 +140,15 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
140140
* by `inputCol`. A warning will be printed if both are set.
141141
*/
142142
private[feature] def isBucketizeMultipleColumns(): Boolean = {
143-
if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) ||
144-
isSet(inputCol) && isSet(outputCols)) {
145-
throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " +
146-
"only supports setting either `inputCol` or `inputCols`.")
147-
} else if (isSet(inputCols)) {
148-
true
149-
} else {
150-
false
143+
inputColsSanityCheck()
144+
outputColsSanityCheck()
145+
if (isSet(inputCol) && isSet(splitsArray)) {
146+
raiseIncompatibleParamsException("inputCol", "splitsArray")
147+
}
148+
if (isSet(inputCols) && isSet(splits)) {
149+
raiseIncompatibleParamsException("inputCols", "splits")
151150
}
151+
isSet(inputCols)
152152
}
153153

154154
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,14 @@ trait Params extends Identifiable with Serializable {
834834
}
835835
to
836836
}
837+
838+
final def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
839+
throw new IllegalArgumentException(
840+
s"""
841+
|Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports
842+
|setting either `$paramName1` or `$paramName2`.
843+
""".stripMargin)
844+
}
837845
}
838846

839847
/**

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,16 @@ trait HasInputCols extends Params {
236236

237237
/** @group getParam */
238238
final def getInputCols: Array[String] = $(inputCols)
239+
240+
final def inputColsSanityCheck(): Unit = {
241+
this match {
242+
case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) =>
243+
raiseIncompatibleParamsException("inputCols", "inputCol")
244+
case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) =>
245+
raiseIncompatibleParamsException("inputCols", "outputCol")
246+
case _ =>
247+
}
248+
}
239249
}
240250

241251
/**
@@ -272,6 +282,16 @@ trait HasOutputCols extends Params {
272282

273283
/** @group getParam */
274284
final def getOutputCols: Array[String] = $(outputCols)
285+
286+
final def outputColsSanityCheck(): Unit = {
287+
this match {
288+
case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) =>
289+
raiseIncompatibleParamsException("outputCols", "inputCol")
290+
case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) =>
291+
raiseIncompatibleParamsException("outputCols", "outputCol")
292+
case _ =>
293+
}
294+
}
275295
}
276296

277297
/**

0 commit comments

Comments
 (0)