diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f8a0c528bff0..7a8b88e5264a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -22,7 +22,6 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.types._ @@ -193,6 +193,8 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression + val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName") + // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions @@ -289,35 +291,35 @@ object FunctionRegistry { expression[CovPopulation]("covar_pop"), expression[CovSample]("covar_samp"), expression[First]("first"), - expression[First]("first_value"), + expression[First]("first_value", true), expression[Kurtosis]("kurtosis"), expression[Last]("last"), - expression[Last]("last_value"), + expression[Last]("last_value", true), expression[Max]("max"), expression[MaxBy]("max_by"), - expression[Average]("mean"), + expression[Average]("mean", true), expression[Min]("min"), expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), - expression[ApproximatePercentile]("approx_percentile"), - expression[StddevSamp]("std"), - expression[StddevSamp]("stddev"), + expression[ApproximatePercentile]("approx_percentile", true), + expression[StddevSamp]("std", true), + expression[StddevSamp]("stddev", true), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), - expression[VarianceSamp]("variance"), + expression[VarianceSamp]("variance", true), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expressionWithAlias[BoolAnd]("every"), - expressionWithAlias[BoolAnd]("bool_and"), - expressionWithAlias[BoolOr]("any"), - expressionWithAlias[BoolOr]("some"), - expressionWithAlias[BoolOr]("bool_or"), + expression[BoolAnd]("every", true), + expression[BoolAnd]("bool_and"), + expression[BoolOr]("any", true), + expression[BoolOr]("some", true), + expression[BoolOr]("bool_or"), // string functions expression[Ascii]("ascii"), @@ -573,7 +575,7 @@ object FunctionRegistry { val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet /** See usage above. */ - private def expression[T <: Expression](name: String) + private def expression[T <: Expression](name: String, setAlias: Boolean = false) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main @@ -619,7 +621,9 @@ object FunctionRegistry { throw new AnalysisException(invalidArgumentsMsg) } try { - f.newInstance(expressions : _*).asInstanceOf[Expression] + val exp = f.newInstance(expressions : _*).asInstanceOf[Expression] + if (setAlias) exp.setTagValue(FUNC_ALIAS, name) + exp } catch { // the exception is an invocation exception. To get a meaningful message, we need the // cause. @@ -631,42 +635,6 @@ object FunctionRegistry { (name, (expressionInfo[T](name), builder)) } - private def expressionWithAlias[T <: Expression](name: String) - (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { - val constructors = tag.runtimeClass.getConstructors - .filter(_.getParameterTypes.head == classOf[String]) - assert(constructors.length == 1) - val builder = (expressions: Seq[Expression]) => { - val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression]) - val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors - .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression])) - .map(_.getParameterCount - 1).distinct.sorted - val invalidArgumentsMsg = if (validParametersCount.length == 0) { - s"Invalid arguments for function $name" - } else { - val expectedNumberOfParameters = if (validParametersCount.length == 1) { - validParametersCount.head.toString - } else { - validParametersCount.init.mkString("one of ", ", ", " and ") + - validParametersCount.last - } - s"Invalid number of arguments for function $name. " + - s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}" - } - throw new AnalysisException(invalidArgumentsMsg) - } - try { - f.newInstance(name.toString +: expressions: _*).asInstanceOf[Expression] - } catch { - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - case e: Exception => throw new AnalysisException(e.getCause.getMessage) - } - } - (name, (expressionInfo[T](name), builder)) - } - /** * Creates a function registry lookup entry for cast aliases (SPARK-16730). * For example, if name is "int", and dataType is IntegerType, this means int(x) would become diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index ea0ed2e8fa11..b143ddef6a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import com.google.common.primitives.{Doubles, Ints, Longs} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest @@ -185,7 +185,8 @@ case class ApproximatePercentile( if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType } - override def prettyName: String = "percentile_approx" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("percentile_approx") override def serialize(obj: PercentileDigest): Array[Byte] = { ApproximatePercentile.serializer.serialize(obj) @@ -321,4 +322,5 @@ object ApproximatePercentile { } val serializer: PercentileDigestSerializer = new PercentileDigestSerializer + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index aaad3c7bcefa..9bb048a9851e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ since = "1.0.0") case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - override def prettyName: String = "avg" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg") override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 8ce8dfa19c01..bf402807d62d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -174,7 +175,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } - override def prettyName: String = "stddev_samp" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp") } // Compute the population variance of a column @@ -215,7 +217,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } - override def prettyName: String = "var_samp" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp") } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9f351395846e..8de866ed9fb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -113,5 +113,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override lazy val evaluateExpression: AttributeReference = first - override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first") + + override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 405719faaeb5..f8af0cd1f303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -111,5 +111,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val evaluateExpression: AttributeReference = last - override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last") + + override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index acb0af0248a7..a1cd4a77d044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -52,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) false """, since = "3.0.0") -case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = funcName +case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and") } @ExpressionDescription( @@ -68,6 +68,6 @@ case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBoolean false """, since = "3.0.0") -case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = funcName +case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index c33027434152..f64b6e00373f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral)) - case BoolOr(_, arg) => Max(arg) - case BoolAnd(_, arg) => Min(arg) + case BoolOr(arg) => Max(arg) + case BoolAnd(arg) => Min(arg) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index f944b4ad87e4..86a1f1fb58a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -153,8 +153,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum(Symbol("stringField"))) assertSuccess(Average(Symbol("stringField"))) assertSuccess(Min(Symbol("arrayField"))) - assertSuccess(new BoolAnd("bool_and", Symbol("booleanField"))) - assertSuccess(new BoolOr("bool_or", Symbol("booleanField"))) + assertSuccess(new BoolAnd(Symbol("booleanField"))) + assertSuccess(new BoolOr(Symbol("booleanField"))) assertError(Min(Symbol("mapField")), "min does not support ordering on type") assertError(Max(Symbol("mapField")), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index ed5ced8c8c0f..62a166649708 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -128,7 +128,7 @@ NULL 1 SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM testData -- !query 13 schema -struct +struct -- !query 13 output -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index 54ceacd3b3b3..4721ceb03a96 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -241,7 +241,7 @@ NaN SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 19 schema -struct +struct -- !query 19 output 16900.0 18491.666666666668 @@ -254,7 +254,7 @@ NaN SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 20 schema -struct +struct -- !query 20 output 16900.0 18491.666666666668 @@ -267,7 +267,7 @@ NaN SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 21 schema -struct +struct -- !query 21 output 16900.0 18491.666666666668 @@ -280,7 +280,7 @@ NaN SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 22 schema -struct +struct -- !query 22 output 16900.0 18491.666666666668 @@ -405,7 +405,7 @@ NaN SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 31 schema -struct +struct -- !query 31 output 130.0 135.9840676942217 @@ -419,7 +419,7 @@ NaN SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 32 schema -struct +struct -- !query 32 output 130.0 135.9840676942217 @@ -433,7 +433,7 @@ NaN SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 33 schema -struct +struct -- !query 33 output 130.0 135.9840676942217 @@ -447,7 +447,7 @@ NaN SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n) -- !query 34 schema -struct +struct -- !query 34 output 130.0 135.9840676942217 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index bdacd184158a..a835740a6a86 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -128,7 +128,7 @@ NULL 1 SELECT SKEWNESS(udf(a)), udf(KURTOSIS(a)), udf(MIN(a)), MAX(udf(a)), udf(AVG(udf(a))), udf(VARIANCE(a)), STDDEV(udf(a)), udf(SUM(a)), udf(COUNT(a)) FROM testData -- !query 13 schema -struct +struct -- !query 13 output -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 @@ -293,7 +293,7 @@ struct<> -- !query 31 SELECT udf(every(v)), udf(some(v)), any(v) FROM test_agg WHERE 1 = 0 -- !query 31 schema -struct +struct -- !query 31 output NULL NULL NULL @@ -301,7 +301,7 @@ NULL NULL NULL -- !query 32 SELECT udf(every(udf(v))), some(v), any(v) FROM test_agg WHERE k = 4 -- !query 32 schema -struct +struct -- !query 32 output NULL NULL NULL @@ -309,7 +309,7 @@ NULL NULL NULL -- !query 33 SELECT every(v), udf(some(v)), any(v) FROM test_agg WHERE k = 5 -- !query 33 schema -struct +struct -- !query 33 output false true true @@ -317,7 +317,7 @@ false true true -- !query 34 SELECT udf(k), every(v), udf(some(v)), any(v) FROM test_agg GROUP BY udf(k) -- !query 34 schema -struct +struct -- !query 34 output 1 false true true 2 true true true @@ -339,7 +339,7 @@ struct -- !query 36 SELECT udf(k), udf(every(v)) FROM test_agg GROUP BY udf(k) HAVING every(v) IS NULL -- !query 36 schema -struct +struct -- !query 36 output 4 NULL