diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7cc64d43858c..f8a0c528bff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -313,11 +313,11 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expression[BoolAnd]("every"), - expression[BoolAnd]("bool_and"), - expression[BoolOr]("any"), - expression[BoolOr]("some"), - expression[BoolOr]("bool_or"), + expressionWithAlias[BoolAnd]("every"), + expressionWithAlias[BoolAnd]("bool_and"), + expressionWithAlias[BoolOr]("any"), + expressionWithAlias[BoolOr]("some"), + expressionWithAlias[BoolOr]("bool_or"), // string functions expression[Ascii]("ascii"), @@ -590,12 +590,12 @@ object FunctionRegistry { val builder = (expressions: Seq[Expression]) => { if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. - Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match { - case Success(e) => e - case Failure(e) => - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - throw new AnalysisException(e.getCause.getMessage) + try { + varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + } catch { + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + case e: Exception => throw new AnalysisException(e.getCause.getMessage) } } else { // Otherwise, find a constructor method that matches the number of arguments, and use that. @@ -618,12 +618,12 @@ object FunctionRegistry { } throw new AnalysisException(invalidArgumentsMsg) } - Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { - case Success(e) => e - case Failure(e) => - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - throw new AnalysisException(e.getCause.getMessage) + try { + f.newInstance(expressions : _*).asInstanceOf[Expression] + } catch { + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + case e: Exception => throw new AnalysisException(e.getCause.getMessage) } } } @@ -631,6 +631,42 @@ object FunctionRegistry { (name, (expressionInfo[T](name), builder)) } + private def expressionWithAlias[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { + val constructors = tag.runtimeClass.getConstructors + .filter(_.getParameterTypes.head == classOf[String]) + assert(constructors.length == 1) + val builder = (expressions: Seq[Expression]) => { + val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression]) + val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { + val validParametersCount = constructors + .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression])) + .map(_.getParameterCount - 1).distinct.sorted + val invalidArgumentsMsg = if (validParametersCount.length == 0) { + s"Invalid arguments for function $name" + } else { + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}" + } + throw new AnalysisException(invalidArgumentsMsg) + } + try { + f.newInstance(name.toString +: expressions: _*).asInstanceOf[Expression] + } catch { + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + case e: Exception => throw new AnalysisException(e.getCause.getMessage) + } + } + (name, (expressionInfo[T](name), builder)) + } + /** * Creates a function registry lookup entry for cast aliases (SPARK-16730). * For example, if name is "int", and dataType is IntegerType, this means int(x) would become diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index c559fefe3a80..acb0af0248a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -52,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) false """, since = "3.0.0") -case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = "bool_and" +case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = funcName } @ExpressionDescription( @@ -68,6 +68,6 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { false """, since = "3.0.0") -case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = "bool_or" +case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = funcName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index f64b6e00373f..c33027434152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral)) - case BoolOr(arg) => Max(arg) - case BoolAnd(arg) => Min(arg) + case BoolOr(_, arg) => Max(arg) + case BoolAnd(_, arg) => Min(arg) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index feb927264ba6..c83759e8f4c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,8 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) - assertSuccess(new BoolAnd('booleanField)) - assertSuccess(new BoolOr('booleanField)) + assertSuccess(new BoolAnd("bool_and", 'booleanField)) + assertSuccess(new BoolOr("bool_or", 'booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 8b6e370a9867..ed5ced8c8c0f 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -293,7 +293,7 @@ struct<> -- !query 31 SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE 1 = 0 -- !query 31 schema -struct +struct -- !query 31 output NULL NULL NULL NULL NULL @@ -301,7 +301,7 @@ NULL NULL NULL NULL NULL -- !query 32 SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE k = 4 -- !query 32 schema -struct +struct -- !query 32 output NULL NULL NULL NULL NULL @@ -309,7 +309,7 @@ NULL NULL NULL NULL NULL -- !query 33 SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE k = 5 -- !query 33 schema -struct +struct -- !query 33 output false true true false true @@ -317,7 +317,7 @@ false true true false true -- !query 34 SELECT k, every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg GROUP BY k -- !query 34 schema -struct +struct -- !query 34 output 1 false true true false true 2 true true true true true @@ -329,7 +329,7 @@ struct +struct -- !query 35 output 1 false 3 false @@ -339,7 +339,7 @@ struct -- !query 36 SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL -- !query 36 schema -struct +struct -- !query 36 output 4 NULL @@ -380,7 +380,7 @@ SELECT every(1) struct<> -- !query 39 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_and(1)' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [int].; line 1 pos 7 +cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 -- !query 40 @@ -389,7 +389,7 @@ SELECT some(1S) struct<> -- !query 40 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_or(1S)' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [smallint].; line 1 pos 7 +cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 -- !query 41 @@ -398,7 +398,7 @@ SELECT any(1L) struct<> -- !query 41 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_or(1L)' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [bigint].; line 1 pos 7 +cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 -- !query 42 @@ -407,7 +407,7 @@ SELECT every("true") struct<> -- !query 42 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_and('true')' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [string].; line 1 pos 7 +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 -- !query 43 @@ -431,7 +431,7 @@ cannot resolve 'bool_or(1.0D)' due to data type mismatch: Input to function 'boo -- !query 45 SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg -- !query 45 schema -struct +struct -- !query 45 output 1 false false 1 true false @@ -448,7 +448,7 @@ struct +struct -- !query 46 output 1 false false 1 true true @@ -465,7 +465,7 @@ struct +struct -- !query 47 output 1 false false 1 true true diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index ea2cab703eaa..bdacd184158a 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -293,7 +293,7 @@ struct<> -- !query 31 SELECT udf(every(v)), udf(some(v)), any(v) FROM test_agg WHERE 1 = 0 -- !query 31 schema -struct +struct -- !query 31 output NULL NULL NULL @@ -301,7 +301,7 @@ NULL NULL NULL -- !query 32 SELECT udf(every(udf(v))), some(v), any(v) FROM test_agg WHERE k = 4 -- !query 32 schema -struct +struct -- !query 32 output NULL NULL NULL @@ -309,7 +309,7 @@ NULL NULL NULL -- !query 33 SELECT every(v), udf(some(v)), any(v) FROM test_agg WHERE k = 5 -- !query 33 schema -struct +struct -- !query 33 output false true true @@ -317,7 +317,7 @@ false true true -- !query 34 SELECT udf(k), every(v), udf(some(v)), any(v) FROM test_agg GROUP BY udf(k) -- !query 34 schema -struct +struct -- !query 34 output 1 false true true 2 true true true @@ -329,7 +329,7 @@ struct +struct -- !query 35 output 1 false 3 false @@ -339,7 +339,7 @@ struct -- !query 36 SELECT udf(k), udf(every(v)) FROM test_agg GROUP BY udf(k) HAVING every(v) IS NULL -- !query 36 schema -struct +struct -- !query 36 output 4 NULL @@ -380,7 +380,7 @@ SELECT every(udf(1)) struct<> -- !query 39 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_and(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [int].; line 1 pos 7 +cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 -- !query 40 @@ -389,7 +389,7 @@ SELECT some(udf(1S)) struct<> -- !query 40 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_or(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [smallint].; line 1 pos 7 +cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 -- !query 41 @@ -398,7 +398,7 @@ SELECT any(udf(1L)) struct<> -- !query 41 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_or(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [bigint].; line 1 pos 7 +cannot resolve 'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 -- !query 42 @@ -407,13 +407,13 @@ SELECT udf(every("true")) struct<> -- !query 42 output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_and('true')' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [string].; line 1 pos 11 +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 11 -- !query 43 SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg -- !query 43 schema -struct +struct -- !query 43 output 1 false false 1 true false @@ -430,7 +430,7 @@ struct +struct -- !query 44 output 1 false false 1 true true @@ -447,7 +447,7 @@ struct +struct -- !query 45 output 1 false false 1 true true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index d5c16a30ade5..f968fbb27d4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -95,8 +95,8 @@ class ExplainSuite extends QueryTest with SharedSparkSession { // plan should show the rewritten aggregate expression. val df = sql("SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k") checkKeywordsExistsInExplain(df, - "Aggregate [k#x], [k#x, min(v#x) AS bool_and(v)#x, max(v#x) AS bool_or(v)#x, " + - "max(v#x) AS bool_or(v)#x]") + "Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, " + + "max(v#x) AS any(v)#x]") } }