From 8f3581cf780dabdbf27ffd7924edf8c344d69b30 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 15 Dec 2017 19:11:18 +0100 Subject: [PATCH 01/18] [SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set --- .../apache/spark/ml/feature/Bucketizer.scala | 8 +++--- .../spark/ml/feature/BucketizerSuite.scala | 25 ++++++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 8299a3e95d82..e945909cd439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -140,10 +140,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * by `inputCol`. A warning will be printed if both are set. */ private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol)) { - logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + - "`Bucketizer` only map one column specified by `inputCol`") - false + if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) || + isSet(inputCol) && isSet(outputCols)) { + throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " + + "only supports setting either `inputCol` or `inputCols`.") } else if (isSet(inputCols)) { true } else { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index d9c97ae8067d..b1063c70a5bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -402,14 +402,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } test("Both inputCol and inputCols are set") { - val bucket = new Bucketizer() + val feature1 = Array(-0.5, -0.3, 0.0, 0.2) + val feature2 = Array(-0.3, -0.2, 0.5, 0.0) + val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2") + + val invalid1 = new Bucketizer() .setInputCol("feature1") .setOutputCol("result") .setSplits(Array(-0.5, 0.0, 0.5)) .setInputCols(Array("feature1", "feature2")) - // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. - assert(bucket.isBucketizeMultipleColumns() == false) + val invalid2 = new Bucketizer() + .setOutputCol("result") + .setSplits(Array(-0.5, 0.0, 0.5)) + .setInputCols(Array("feature1", "feature2")) + + val invalid3 = new Bucketizer() + .setInputCol("feature1") + .setSplits(Array(-0.5, 0.0, 0.5)) + .setOutputCols(Array("result1", "result2")) + + Seq(invalid1, invalid2, invalid3).foreach { bucketizer => + // When both inputCol and inputCols are set, we throw Exception. 
+ val e = intercept[IllegalArgumentException] { + bucketizer.transform(df) + } + assert(e.getMessage.contains("Both `inputCol` and `inputCols` are set")) + } } } From bb0c0d29f4eec137bbd90ae068a7f8a30c92ea9f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 19:09:32 +0100 Subject: [PATCH 02/18] address comments --- .../apache/spark/ml/feature/Bucketizer.scala | 22 +++++++++---------- .../org/apache/spark/ml/param/params.scala | 8 +++++++ .../spark/ml/param/shared/sharedParams.scala | 20 +++++++++++++++++ 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e945909cd439..4e0d647dd1d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -34,9 +34,9 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that - * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and - * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is - * only used for single column usage, and `splitsArray` is for multiple columns. + * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The + * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple + * columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) @@ -140,15 +140,15 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * by `inputCol`. A warning will be printed if both are set. */ private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) || - isSet(inputCol) && isSet(outputCols)) { - throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " + - "only supports setting either `inputCol` or `inputCols`.") - } else if (isSet(inputCols)) { - true - } else { - false + inputColsSanityCheck() + outputColsSanityCheck() + if (isSet(inputCol) && isSet(splitsArray)) { + raiseIncompatibleParamsException("inputCol", "splitsArray") + } + if (isSet(inputCols) && isSet(splits)) { + raiseIncompatibleParamsException("inputCols", "splits") } + isSet(inputCols) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b4b401ac4aa..c5af53e91f4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -834,6 +834,14 @@ trait Params extends Identifiable with Serializable { } to } + + final def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { + throw new IllegalArgumentException( + s""" + |Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports + |setting either `$paramName1` or `$paramName2`. 
+ """.stripMargin) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f1..931744d2b594 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -236,6 +236,16 @@ trait HasInputCols extends Params { /** @group getParam */ final def getInputCols: Array[String] = $(inputCols) + + final def inputColsSanityCheck(): Unit = { + this match { + case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) => + raiseIncompatibleParamsException("inputCols", "inputCol") + case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) => + raiseIncompatibleParamsException("inputCols", "outputCol") + case _ => + } + } } /** @@ -272,6 +282,16 @@ trait HasOutputCols extends Params { /** @group getParam */ final def getOutputCols: Array[String] = $(outputCols) + + final def outputColsSanityCheck(): Unit = { + this match { + case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) => + raiseIncompatibleParamsException("outputCols", "inputCol") + case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) => + raiseIncompatibleParamsException("outputCols", "outputCol") + case _ => + } + } } /** From 9f5680032e93ef392fb76d4518ab8a4b0f479aae Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 19:13:22 +0100 Subject: [PATCH 03/18] fix doc --- .../src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 4e0d647dd1d7..6d4652439b64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -137,7 +137,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Determines whether this `Bucketizer` is going to map multiple columns. If and only if * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. A warning will be printed if both are set. + * by `inputCol`. An exception will be thrown if both are set. 
*/ private[feature] def isBucketizeMultipleColumns(): Boolean = { inputColsSanityCheck() From f593f5b5c80f80787767786aef2f82dde41ead99 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 19:56:17 +0100 Subject: [PATCH 04/18] fix mima error --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c5af53e91f4c..9f0b5feef930 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -835,7 +835,7 @@ trait Params extends Identifiable with Serializable { to } - final def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { + protected def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { throw new IllegalArgumentException( s""" |Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports From 2ecdc73c83e4207a63fad65306dfe69de24d5bee Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 20:41:06 +0100 Subject: [PATCH 05/18] use ParamValidators --- .../apache/spark/ml/feature/Bucketizer.scala | 7 ++-- .../org/apache/spark/ml/param/params.scala | 32 ++++++++++++++----- .../spark/ml/param/shared/sharedParams.scala | 20 ------------ 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6d4652439b64..2bc974a45748 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -140,13 +140,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * by `inputCol`. An exception will be thrown if both are set. */ private[feature] def isBucketizeMultipleColumns(): Boolean = { - inputColsSanityCheck() - outputColsSanityCheck() + ParamValidators.assertColOrCols(this) if (isSet(inputCol) && isSet(splitsArray)) { - raiseIncompatibleParamsException("inputCol", "splitsArray") + ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray") } if (isSet(inputCols) && isSet(splits)) { - raiseIncompatibleParamsException("inputCols", "splits") + ParamValidators.raiseIncompatibleParamsException("inputCols", "splits") } isSet(inputCols) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9f0b5feef930..233c2e09c58e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkException import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable /** @@ -249,6 +250,29 @@ object ParamValidators { def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => value.length > lowerBound } + + /** + * Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If + * this is not true, an `IllegalArgumentException` is raised. 
+ * @param model + */ + def assertColOrCols(model: Params): Unit = { + model match { + case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) => + raiseIncompatibleParamsException("inputCols", "inputCol") + case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) => + raiseIncompatibleParamsException("outputCols", "inputCol") + case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) => + raiseIncompatibleParamsException("inputCols", "outputCol") + case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) => + raiseIncompatibleParamsException("outputCols", "outputCol") + case _ => + } + } + + def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { + throw new IllegalArgumentException(s"Both `$paramName1` and `$paramName2` are set.") + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... @@ -834,14 +858,6 @@ trait Params extends Identifiable with Serializable { } to } - - protected def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { - throw new IllegalArgumentException( - s""" - |Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports - |setting either `$paramName1` or `$paramName2`. - """.stripMargin) - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 931744d2b594..13425dacc9f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -236,16 +236,6 @@ trait HasInputCols extends Params { /** @group getParam */ final def getInputCols: Array[String] = $(inputCols) - - final def inputColsSanityCheck(): Unit = { - this match { - case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) => - raiseIncompatibleParamsException("inputCols", "inputCol") - case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) => - raiseIncompatibleParamsException("inputCols", "outputCol") - case _ => - } - } } /** @@ -282,16 +272,6 @@ trait HasOutputCols extends Params { /** @group getParam */ final def getOutputCols: Array[String] = $(outputCols) - - final def outputColsSanityCheck(): Unit = { - this match { - case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) => - raiseIncompatibleParamsException("outputCols", "inputCol") - case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) => - raiseIncompatibleParamsException("outputCols", "outputCol") - case _ => - } - } } /** From 26fe05e953eaa87392cd037c7fd970037c538fea Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 21:33:47 +0100 Subject: [PATCH 06/18] fix ut --- .../scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index b1063c70a5bd..9515143baf9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -423,11 +423,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) Seq(invalid1, invalid2, 
invalid3).foreach { bucketizer => - // When both inputCol and inputCols are set, we throw Exception. - val e = intercept[IllegalArgumentException] { + // When both inputCol/outputCol and inputCols/outputCols are set, we throw Exception. + intercept[IllegalArgumentException] { bucketizer.transform(df) } - assert(e.getMessage.contains("Both `inputCol` and `inputCols` are set")) } } } From 64634b5b505b223f1588f5604c1049b947d1db32 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Dec 2017 23:14:40 +0100 Subject: [PATCH 07/18] address review comments --- .../org/apache/spark/ml/feature/Bucketizer.scala | 16 ++++++++-------- .../scala/org/apache/spark/ml/param/params.scala | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 2bc974a45748..852af4bf4fbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -137,16 +137,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Determines whether this `Bucketizer` is going to map multiple columns. If and only if * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. An exception will be thrown if both are set. + * by `inputCol`. */ private[feature] def isBucketizeMultipleColumns(): Boolean = { - ParamValidators.assertColOrCols(this) - if (isSet(inputCol) && isSet(splitsArray)) { - ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray") - } - if (isSet(inputCols) && isSet(splits)) { - ParamValidators.raiseIncompatibleParamsException("inputCols", "splits") - } isSet(inputCols) } @@ -200,6 +193,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + ParamValidators.assertColOrCols(this) + if (isSet(inputCol) && isSet(splitsArray)) { + ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray") + } + if (isSet(inputCols) && isSet(splits)) { + ParamValidators.raiseIncompatibleParamsException("inputCols", "splits") + } if (isBucketizeMultipleColumns()) { var transformedSchema = schema $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 233c2e09c58e..b903e3586f83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -256,7 +256,7 @@ object ParamValidators { * this is not true, an `IllegalArgumentException` is raised. 
* @param model */ - def assertColOrCols(model: Params): Unit = { + private[spark] def assertColOrCols(model: Params): Unit = { model match { case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) => raiseIncompatibleParamsException("inputCols", "inputCol") @@ -270,8 +270,10 @@ object ParamValidators { } } - def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = { - throw new IllegalArgumentException(s"Both `$paramName1` and `$paramName2` are set.") + private[spark] def raiseIncompatibleParamsException( + paramName1: String, + paramName2: String): Unit = { + throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot be both set.") } } From 9872bfdb5a428a74ba387c5d86a271621fef0a04 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 20 Dec 2017 22:16:04 +0100 Subject: [PATCH 08/18] add checkMultiColumnParams --- .../spark/ml/feature/BucketizerSuite.scala | 29 +---------- .../apache/spark/ml/param/ParamsSuite.scala | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 9515143baf9e..3ffd7b59034f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -401,33 +401,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("Both inputCol and inputCols are set") { - val feature1 = Array(-0.5, -0.3, 0.0, 0.2) - val feature2 = Array(-0.3, -0.2, 0.5, 0.0) - val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2") - - val invalid1 = new Bucketizer() - .setInputCol("feature1") - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - val invalid2 = new Bucketizer() - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - val invalid3 = new Bucketizer() - .setInputCol("feature1") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setOutputCols(Array("result1", "result2")) - - Seq(invalid1, invalid2, invalid3).foreach { bucketizer => - // When both inputCol/outputCol and inputCols/outputCols are set, we throw Exception. 
- intercept[IllegalArgumentException] { - bucketizer.transform(df) - } - } + test("assert exception is thrown is both multi-column and single-column params are set") { + ParamsSuite.checkMultiColumnParams(classOf[Bucketizer], spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 85198ad4c913..03ae8f16f08b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -20,8 +20,11 @@ package org.apache.spark.ml.param import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util.MyParams +import org.apache.spark.sql.{Dataset, SparkSession} class ParamsSuite extends SparkFunSuite { @@ -430,4 +433,49 @@ object ParamsSuite extends SparkFunSuite { require(copyReturnType === obj.getClass, s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } + + /** + * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and + * in case both `outputCols` and `outputCol` are set. + * These checks are performed only whether the class extends respectively both `HasInputCols` and + * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`. + * + * @param paramsClass The Class to be checked + * @param spark A `SparkSession` instance to use + */ + def checkMultiColumnParams(paramsClass: Class[_ <: Params], spark: SparkSession): Unit = { + import spark.implicits._ + // create fake input Dataset + val feature1 = Array(-1.0, 0.0, 1.0) + val feature2 = Array(1.0, 0.0, -1.0) + val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2") + + if (paramsClass.isAssignableFrom(classOf[HasInputCols]) + && paramsClass.isAssignableFrom(classOf[HasInputCol])) { + val model = paramsClass.newInstance() + model.set(model.asInstanceOf[HasInputCols].inputCols, Array("feature1", "feature2")) + model.set(model.asInstanceOf[HasInputCol].inputCol, "features1") + val e = intercept[IllegalArgumentException] { + model match { + case t: Transformer => t.transform(df) + case e: Estimator[_] => e.fit(df) + } + } + assert(e.getMessage.contains("cannot be both set")) + } + + if (paramsClass.isAssignableFrom(classOf[HasOutputCols]) + && paramsClass.isAssignableFrom(classOf[HasOutputCol])) { + val model = paramsClass.newInstance() + model.set(model.asInstanceOf[HasOutputCols].outputCols, Array("result1", "result2")) + model.set(model.asInstanceOf[HasOutputCol].outputCol, "result1") + val e = intercept[IllegalArgumentException] { + model match { + case t: Transformer => t.transform(df) + case e: Estimator[_] => e.fit(df) + } + } + assert(e.getMessage.contains("cannot be both set")) + } + } } From d0b8d06bd92f47a8b495411d966ba36f998cd5b3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 21 Dec 2017 10:55:49 +0100 Subject: [PATCH 09/18] address review comments --- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../spark/ml/feature/BucketizerSuite.scala | 3 +- .../apache/spark/ml/param/ParamsSuite.scala | 28 ++++++++----------- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 852af4bf4fbb..e654228a469a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -193,7 +193,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - ParamValidators.assertColOrCols(this) + ParamValidators.checkMultiColumnParams(this) if (isSet(inputCol) && isSet(splitsArray)) { ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index b903e3586f83..9cbc84ed0f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -256,7 +256,7 @@ object ParamValidators { * this is not true, an `IllegalArgumentException` is raised. * @param model */ - private[spark] def assertColOrCols(model: Params): Unit = { + private[spark] def checkMultiColumnParams(model: Params): Unit = { model match { case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) => raiseIncompatibleParamsException("inputCols", "inputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 3ffd7b59034f..cd1ddaf26e96 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -402,7 +402,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } test("assert exception is thrown is both multi-column and single-column params are set") { - ParamsSuite.checkMultiColumnParams(classOf[Bucketizer], spark) + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 03ae8f16f08b..82a3b5df28f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util.MyParams -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.Dataset class ParamsSuite extends SparkFunSuite { @@ -441,24 +441,20 @@ object ParamsSuite extends SparkFunSuite { * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`. 
* * @param paramsClass The Class to be checked - * @param spark A `SparkSession` instance to use + * @param dataset A `Dataset` to use in the tests */ - def checkMultiColumnParams(paramsClass: Class[_ <: Params], spark: SparkSession): Unit = { - import spark.implicits._ - // create fake input Dataset - val feature1 = Array(-1.0, 0.0, 1.0) - val feature2 = Array(1.0, 0.0, -1.0) - val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2") + def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = { + val cols = dataset.columns if (paramsClass.isAssignableFrom(classOf[HasInputCols]) && paramsClass.isAssignableFrom(classOf[HasInputCol])) { val model = paramsClass.newInstance() - model.set(model.asInstanceOf[HasInputCols].inputCols, Array("feature1", "feature2")) - model.set(model.asInstanceOf[HasInputCol].inputCol, "features1") + model.set(model.asInstanceOf[HasInputCols].inputCols, cols) + model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0)) val e = intercept[IllegalArgumentException] { model match { - case t: Transformer => t.transform(df) - case e: Estimator[_] => e.fit(df) + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) } } assert(e.getMessage.contains("cannot be both set")) @@ -467,12 +463,12 @@ object ParamsSuite extends SparkFunSuite { if (paramsClass.isAssignableFrom(classOf[HasOutputCols]) && paramsClass.isAssignableFrom(classOf[HasOutputCol])) { val model = paramsClass.newInstance() - model.set(model.asInstanceOf[HasOutputCols].outputCols, Array("result1", "result2")) - model.set(model.asInstanceOf[HasOutputCol].outputCol, "result1") + model.set(model.asInstanceOf[HasOutputCols].outputCols, cols) + model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0)) val e = intercept[IllegalArgumentException] { model match { - case t: Transformer => t.transform(df) - case e: Estimator[_] => e.fit(df) + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) } } assert(e.getMessage.contains("cannot be both set")) From b20fb91da9137347ff6710cfbe2af78277e0036e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 21 Dec 2017 20:40:57 +0100 Subject: [PATCH 10/18] remove two unneeded checks --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9cbc84ed0f68..5890de7385bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -260,10 +260,6 @@ object ParamValidators { model match { case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) => raiseIncompatibleParamsException("inputCols", "inputCol") - case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) => - raiseIncompatibleParamsException("outputCols", "inputCol") - case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) => - raiseIncompatibleParamsException("inputCols", "outputCol") case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) => raiseIncompatibleParamsException("outputCols", "outputCol") case _ => From 09d652d7f75a11a4376236421f51dfc8910d7b20 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 30 Dec 2017 11:09:01 +0100 Subject: [PATCH 11/18] address comments --- .../org/apache/spark/ml/feature/Bucketizer.scala | 15 
+++------------ .../scala/org/apache/spark/ml/param/params.scala | 2 +- .../apache/spark/ml/feature/BucketizerSuite.scala | 2 +- .../org/apache/spark/ml/param/ParamsSuite.scala | 2 +- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e654228a469a..154a5f7ec448 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -134,20 +134,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - /** - * Determines whether this `Bucketizer` is going to map multiple columns. If and only if - * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. - */ - private[feature] def isBucketizeMultipleColumns(): Boolean = { - isSet(inputCols) - } - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + val (inputColumns, outputColumns) = if (isSet(inputCols)) { ($(inputCols).toSeq, $(outputCols).toSeq) } else { (Seq($(inputCol)), Seq($(outputCol))) @@ -162,7 +153,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val seqOfSplits = if (isBucketizeMultipleColumns()) { + val seqOfSplits = if (isSet(inputCols)) { $(splitsArray).toSeq } else { Seq($(splits)) @@ -200,7 +191,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String if (isSet(inputCols) && isSet(splits)) { ParamValidators.raiseIncompatibleParamsException("inputCols", "splits") } - if (isBucketizeMultipleColumns()) { + if (isSet(inputCols)) { var transformedSchema = schema $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5890de7385bd..f9cbcdbc3c74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -269,7 +269,7 @@ object ParamValidators { private[spark] def raiseIncompatibleParamsException( paramName1: String, paramName2: String): Unit = { - throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot be both set.") + throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot both be set.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index cd1ddaf26e96..94b0534956bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -401,7 +401,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("assert exception is thrown is both multi-column and single-column params are set") { + test("assert exception is thrown if both multi-column and single-column params are set") { val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df) } diff 
--git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 82a3b5df28f6..0869f1363170 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -437,7 +437,7 @@ object ParamsSuite extends SparkFunSuite { /** * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and * in case both `outputCols` and `outputCol` are set. - * These checks are performed only whether the class extends respectively both `HasInputCols` and + * These checks are performed only when the class extends respectively both `HasInputCols` and * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`. * * @param paramsClass The Class to be checked From a0c0fed21d6fe99d5fb283190fe64f66645a5880 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 8 Jan 2018 14:32:46 +0100 Subject: [PATCH 12/18] remove isBucketizeMultipleColumns --- .../org/apache/spark/ml/feature/BucketizerSuite.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 94b0534956bc..af54875e956b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer1.isBucketizeMultipleColumns()) - bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), Seq("result1", "result2"), @@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result")) .setSplitsArray(Array(splits(0))) - assert(bucketizer2.isBucketizeMultipleColumns()) - withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer2.transform(badDF1).collect() @@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), Seq("expected1", "expected2")) @@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - bucketizer.setHandleInvalid("keep") BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), @@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - assert(t.isBucketizeMultipleColumns()) testDefaultReadWrite(t) } @@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) - assert(bucket.isBucketizeMultipleColumns()) - val pl = new Pipeline() .setStages(Array(bucket)) 
.fit(df) From 25b9bd4e2a5dd808a5db90fdbe5492872fefdeea Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 17 Jan 2018 17:49:04 +0100 Subject: [PATCH 13/18] address comments --- .../apache/spark/ml/feature/Bucketizer.scala | 13 +++---- .../org/apache/spark/ml/param/params.scala | 32 +++++++-------- .../spark/ml/feature/BucketizerSuite.scala | 7 +++- .../apache/spark/ml/param/ParamsSuite.scala | 39 +++++-------------- 4 files changed, 37 insertions(+), 54 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 154a5f7ec448..a3ea9c317200 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -184,16 +184,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - ParamValidators.checkMultiColumnParams(this) - if (isSet(inputCol) && isSet(splitsArray)) { - ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray") - } - if (isSet(inputCols) && isSet(splits)) { - ParamValidators.raiseIncompatibleParamsException("inputCols", "splits") - } + ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols") + ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols") + ParamValidators.checkExclusiveParams(this, "splits", "splitsArray") + if (isSet(inputCols)) { var transformedSchema = schema - $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index f9cbcdbc3c74..8fbb0e1b2a3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -27,11 +27,11 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.slf4j.LoggerFactory import org.apache.spark.SparkException import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector} -import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable /** @@ -167,6 +167,8 @@ private[ml] object Param { @DeveloperApi object ParamValidators { + private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass) + /** (private[param]) Default validation always return true */ private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true @@ -252,24 +254,22 @@ object ParamValidators { } /** - * Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If - * this is not true, an `IllegalArgumentException` is raised. - * @param model + * Checks that only one of the params passed as arguments is set. If this is not true, an + * `IllegalArgumentException` is raised. 
   */
-  private[spark] def checkMultiColumnParams(model: Params): Unit = {
-    model match {
-      case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
-        raiseIncompatibleParamsException("inputCols", "inputCol")
-      case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
-        raiseIncompatibleParamsException("outputCols", "outputCol")
-      case _ =>
+  def checkExclusiveParams(model: Params, params: String*): Unit = {
+    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
+    if (nonExistingParams.nonEmpty) {
+      val pronoun = if (nonExistingParams.size == 1) "It" else "They"
+      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
+        s"exclusive params. $pronoun don't exist for the specified model.")
     }
-  }
 
-  private[spark] def raiseIncompatibleParamsException(
-      paramName1: String,
-      paramName2: String): Unit = {
-    throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot both be set.")
+    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
+      val paramString = existingParams.mkString("`", "`, `", "`")
+      throw new IllegalArgumentException(s"$paramString are exclusive, " +
+        "but more than one among them are set.")
+    }
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index af54875e956b..69027d1b1894 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -392,7 +392,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
 
   test("assert exception is thrown if both multi-column and single-column params are set") {
     val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
-    ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df)
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "result1"),
+      ("outputCols", Array("result1", "result2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("splits", Array(-0.5, 0.0, 0.5)),
+      ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 0869f1363170..94d6125c4c98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -435,43 +435,24 @@ object ParamsSuite extends SparkFunSuite {
   }
 
   /**
-   * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and
-   * in case both `outputCols` and `outputCol` are set.
-   * These checks are performed only when the class extends respectively both `HasInputCols` and
-   * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`.
-   *
-   * @param paramsClass The Class to be checked
-   * @param dataset A `Dataset` to use in the tests
+   * Checks that the class throws an exception in case multiple exclusive params are set
+   * The params to be checked are passed as arguments with their value.
+   * The checks are performed only if all the passed params are defined for the given model.
*/ - def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = { - val cols = dataset.columns - - if (paramsClass.isAssignableFrom(classOf[HasInputCols]) - && paramsClass.isAssignableFrom(classOf[HasInputCol])) { - val model = paramsClass.newInstance() - model.set(model.asInstanceOf[HasInputCols].inputCols, cols) - model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0)) - val e = intercept[IllegalArgumentException] { - model match { - case t: Transformer => t.transform(dataset) - case e: Estimator[_] => e.fit(dataset) - } + def testExclusiveParams(model: Params, dataset: Dataset[_], + paramsAndValues: (String, Any)*): Unit = { + val params = paramsAndValues.map(_._1) + if (params.forall(model.hasParam)) { + paramsAndValues.foreach { case (paramName, paramValue) => + model.set(model.getParam(paramName), paramValue) } - assert(e.getMessage.contains("cannot be both set")) - } - - if (paramsClass.isAssignableFrom(classOf[HasOutputCols]) - && paramsClass.isAssignableFrom(classOf[HasOutputCol])) { - val model = paramsClass.newInstance() - model.set(model.asInstanceOf[HasOutputCols].outputCols, cols) - model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0)) val e = intercept[IllegalArgumentException] { model match { case t: Transformer => t.transform(dataset) case e: Estimator[_] => e.fit(dataset) } } - assert(e.getMessage.contains("cannot be both set")) + assert(e.getMessage.contains("are exclusive, but more than one")) } } } From 18bbf61e1fd79930c9ba0ed5ae5edfce6229ae55 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 19 Jan 2018 22:21:42 -0800 Subject: [PATCH 14/18] strengthened requirements about exclusive Params for single and multi column support --- .../apache/spark/ml/feature/Bucketizer.scala | 15 +++- .../org/apache/spark/ml/param/params.scala | 75 +++++++++++++++---- .../apache/spark/ml/param/ParamsSuite.scala | 27 ++++--- 3 files changed, 84 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index a3ea9c317200..c13bf47eacb9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -32,7 +32,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * + * Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. 
The `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
  * columns.
  */
@@ -184,11 +186,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
-    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
-    ParamValidators.checkExclusiveParams(this, "splits", "splitsArray")
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+      Seq(outputCols, splitsArray))
+
     if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length &&
+        getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
+        s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
+        s"equal lengths, but they have different lengths: " +
+        s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
+
       var transformedSchema = schema
       $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
         SchemaUtils.checkNumericType(transformedSchema, inputCol)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8fbb0e1b2a3b..bd15a8492e3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -27,7 +27,6 @@ import scala.collection.mutable
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
-import org.slf4j.LoggerFactory
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
@@ -167,8 +166,6 @@ private[ml] object Param {
 @DeveloperApi
 object ParamValidators {
 
-  private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)
-
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
 
@@ -254,21 +251,69 @@ object ParamValidators {
   }
 
   /**
-   * Checks that only one of the params passed as arguments is set. If this is not true, an
-   * `IllegalArgumentException` is raised.
+   * Utility for Param validity checks for Transformers which have both single- and multi-column
+   * support.  This utility assumes that `inputCol` indicates single-column usage and
+   * that `inputCols` indicates multi-column usage.
+   *
+   * This checks to ensure that exactly one set of Params has been set, and it
+   * raises an `IllegalArgumentException` if not.
+   *
+   * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
+   *                           set.  This does not need to include `inputCol`.
+   * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
+   *                          set.  This does not need to include `inputCols`.
    */
-  def checkExclusiveParams(model: Params, params: String*): Unit = {
-    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
-    if (nonExistingParams.nonEmpty) {
-      val pronoun = if (nonExistingParams.size == 1) "It" else "They"
-      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
-        s"exclusive params. $pronoun don't exist for the specified model.")
+  def checkSingleVsMultiColumnParams(
+      model: Params,
+      singleColumnParams: Seq[Param[_]],
+      multiColumnParams: Seq[Param[_]]): Unit = {
+    val name = s"${model.getClass.getSimpleName} $model"
+
+    def checkExclusiveParams(
+        isSingleCol: Boolean,
+        requiredParams: Seq[Param[_]],
+        excludedParams: Seq[Param[_]]): Unit = {
+      val badParamsMsgBuilder = new mutable.StringBuilder()
+
+      val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+        .map(_.name).mkString(", ")
+      if (mustUnsetParams.nonEmpty)
+        badParamsMsgBuilder ++=
+          s"The following Params are not applicable and should not be set: $mustUnsetParams."
+
+      val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+        .map(_.name).mkString(", ")
+      if (mustSetParams.nonEmpty)
+        badParamsMsgBuilder ++=
+          s"The following Params must be defined but are not set: $mustSetParams."
+
+      val badParamsMsg = badParamsMsgBuilder.toString()
+
+      if (badParamsMsg.nonEmpty) {
+        val errPrefix = if (isSingleCol) {
+          s"$name has the inputCol Param set for single-column transform."
+        } else {
+          s"$name has the inputCols Param set for multi-column transform."
+        }
+        throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+      }
     }
 
-    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
-      val paramString = existingParams.mkString("`", "`, `", "`")
-      throw new IllegalArgumentException(s"$paramString are exclusive, " +
-        "but more than one among them are set.")
+    val inputCol = model.getParam("inputCol")
+    val inputCols = model.getParam("inputCols")
+
+    if (model.isSet(inputCol)) {
+      require(!model.isSet(inputCols), s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+      checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+        excludedParams = multiColumnParams)
+    } else if (model.isSet(inputCols)) {
+      checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+        excludedParams = singleColumnParams)
+    } else {
+      throw new IllegalArgumentException(s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
     }
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 94d6125c4c98..6ecab7cbf696 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -435,24 +435,23 @@ object ParamsSuite extends SparkFunSuite {
   }
 
   /**
-   * Checks that the class throws an exception in case multiple exclusive params are set
+   * Checks that the class throws an exception in case multiple exclusive params are set.
    * The params to be checked are passed as arguments with their value.
-   * The checks are performed only if all the passed params are defined for the given model.
*/ - def testExclusiveParams(model: Params, dataset: Dataset[_], + def testExclusiveParams( + model: Params, + dataset: Dataset[_], paramsAndValues: (String, Any)*): Unit = { - val params = paramsAndValues.map(_._1) - if (params.forall(model.hasParam)) { - paramsAndValues.foreach { case (paramName, paramValue) => - model.set(model.getParam(paramName), paramValue) - } - val e = intercept[IllegalArgumentException] { - model match { - case t: Transformer => t.transform(dataset) - case e: Estimator[_] => e.fit(dataset) - } + val m = model.copy(ParamMap.empty) + paramsAndValues.foreach { case (paramName, paramValue) => + m.set(m.getParam(paramName), paramValue) + } + val e = intercept[IllegalArgumentException] { + m match { + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) } - assert(e.getMessage.contains("are exclusive, but more than one")) } + assert(e.getMessage.contains("are exclusive, but more than one")) } } From 8c162a335258fc320061bc00653ca1c3d0c13f24 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 21 Jan 2018 12:00:47 +0100 Subject: [PATCH 15/18] fix style error --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index bd15a8492e3b..9a83a5882ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -277,15 +277,17 @@ object ParamValidators { val mustUnsetParams = excludedParams.filter(p => model.isSet(p)) .map(_.name).mkString(", ") - if (mustUnsetParams.nonEmpty) + if (mustUnsetParams.nonEmpty) { badParamsMsgBuilder ++= s"The following Params are not applicable and should not be set: $mustUnsetParams." + } val mustSetParams = requiredParams.filter(p => !model.isDefined(p)) .map(_.name).mkString(", ") - if (mustSetParams.nonEmpty) + if (mustSetParams.nonEmpty) { badParamsMsgBuilder ++= s"The following Params must be defined but are not set: $mustSetParams." 
+      }
 
       val badParamsMsg = badParamsMsgBuilder.toString()

From 7894609642a0cd49627cb4853a275fb2408cec46 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Sun, 21 Jan 2018 13:38:51 +0100
Subject: [PATCH 16/18] fix ut error

---
 mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 6ecab7cbf696..ecada3de91cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -452,6 +452,5 @@ object ParamsSuite extends SparkFunSuite {
         case e: Estimator[_] => e.fit(dataset)
       }
     }
-    assert(e.getMessage.contains("are exclusive, but more than one"))
   }
 }

From ebc6d16586318155180e37d7a1a199aa1a8b9cf2 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 23 Jan 2018 11:58:35 +0100
Subject: [PATCH 17/18] add all cases to UT

---
 .../apache/spark/ml/feature/BucketizerSuite.scala  | 14 ++++++++++++--
 .../org/apache/spark/ml/param/ParamsSuite.scala    |  2 +-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 69027d1b1894..ccf2465a40f5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -394,10 +394,20 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
     ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
       ("inputCols", Array("feature1", "feature2")))
-    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "result1"),
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
       ("outputCols", Array("result1", "result2")))
-    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("splits", Array(-0.5, 0.0, 0.5)),
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
       ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
+
+    // this should fail because at least one of inputCol and inputCols must be set
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"),
+      ("splits", Array(-0.5, 0.0, 0.5)))
+
+    // the following should fail because not all the params are set
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"))
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index ecada3de91cc..cb0267934527 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -446,7 +446,7 @@ object ParamsSuite extends SparkFunSuite {
     paramsAndValues.foreach { case (paramName, paramValue) =>
       m.set(m.getParam(paramName), paramValue)
     }
-    val e = intercept[IllegalArgumentException] {
+    intercept[IllegalArgumentException] {
       m match {
         case t: Transformer => t.transform(dataset)
         case e: Estimator[_] => e.fit(dataset)

From 2bc5cb4948f71ac77b6e5ff824cd8fc6b5d64910 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 24 Jan 2018 16:17:42 +0100
Subject: [PATCH 18/18] review comment

---
 .../scala/org/apache/spark/ml/feature/BucketizerSuite.scala    | 3 +++
 .../src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 1 -
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index ccf2465a40f5..7403680ae3fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -408,6 +408,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     // the following should fail because not all the params are set
     ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
       ("outputCol", "result1"))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df,
+      ("inputCols", Array("feature1", "feature2")),
+      ("outputCols", Array("result1", "result2")))
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index cb0267934527..36e06091d24d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -22,7 +22,6 @@ import java.io.{ByteArrayOutputStream, ObjectOutputStream}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Transformer}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util.MyParams
 import org.apache.spark.sql.Dataset
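
As a quick end-to-end check of where this series lands, the sketch below exercises the final behavior from a Spark shell. It is illustrative only (not part of the patches) and assumes a build containing these changes plus an active `SparkSession` named `spark`:

    import org.apache.spark.ml.feature.Bucketizer

    val df = spark.createDataFrame(Seq((0.5, 0.3), (0.5, -0.4))).toDF("feature1", "feature2")

    // Valid multi-column usage: inputCols, outputCols and splitsArray are all set
    // and have matching lengths, so transformSchema() passes.
    val multiCol = new Bucketizer()
      .setInputCols(Array("feature1", "feature2"))
      .setOutputCols(Array("result1", "result2"))
      .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
    multiCol.transform(df).show()
    // With buckets [-0.5, 0.0) -> 0.0 and [0.0, 0.5] -> 1.0, this yields
    // result1 = [1.0, 1.0] and result2 = [1.0, 0.0].

    // Mixing the single- and multi-column Params now fails fast in transformSchema()
    // ("requires exactly one of inputCol, inputCols Params to be set, but both are set."),
    // instead of silently preferring one of them as before SPARK-22799.
    val mixed = new Bucketizer()
      .setInputCol("feature1")
      .setOutputCol("result")
      .setSplits(Array(-0.5, 0.0, 0.5))
      .setInputCols(Array("feature1", "feature2"))
    mixed.transform(df)  // throws IllegalArgumentException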