From 29fdabae17d682ccfe487daf979e13d7b3569d91 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Mon, 9 Dec 2019 15:03:57 +0530
Subject: [PATCH 01/32] expressionWithAlias for First, Last, StddevSamp,
 VarianceSamp

---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  8 ++++----
 .../catalyst/analysis/FunctionRegistry.scala       | 18 +++++++++---------
 .../sql/catalyst/analysis/TypeCoercion.scala       |  4 ++--
 .../aggregate/CentralMomentAgg.scala               |  8 ++++----
 .../catalyst/expressions/aggregate/First.scala     |  6 +++---
 .../catalyst/expressions/aggregate/Last.scala      |  6 +++---
 .../scala/org/apache/spark/sql/functions.scala     | 12 ++++++------
 7 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 659e4a5c86ec..4a90e25d57d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -692,10 +692,10 @@ class Analyzer(
             // Assumption is the aggregate function ignores nulls. This is true for all current
             // AggregateFunction's with the exception of First and Last in their default mode
             // (which we handle) and possibly some Hive UDAF's.
-            case First(expr, _) =>
-              First(ifExpr(expr), Literal(true))
-            case Last(expr, _) =>
-              Last(ifExpr(expr), Literal(true))
+            case First(funcName, expr, _) =>
+              First(funcName, ifExpr(expr), Literal(true))
+            case Last(funcName, expr, _) =>
+              Last(funcName, ifExpr(expr), Literal(true))
             case a: AggregateFunction =>
               a.withNewChildren(a.children.map(ifExpr))
           }.transform {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f8a0c528bff0..b1752275ae86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -288,11 +288,11 @@ object FunctionRegistry {
     expression[CountIf]("count_if"),
     expression[CovPopulation]("covar_pop"),
     expression[CovSample]("covar_samp"),
-    expression[First]("first"),
-    expression[First]("first_value"),
+    expressionWithAlias[First]("first"),
+    expressionWithAlias[First]("first_value"),
     expression[Kurtosis]("kurtosis"),
-    expression[Last]("last"),
-    expression[Last]("last_value"),
+    expressionWithAlias[Last]("last"),
+    expressionWithAlias[Last]("last_value"),
     expression[Max]("max"),
     expression[MaxBy]("max_by"),
     expression[Average]("mean"),
@@ -302,14 +302,14 @@ object FunctionRegistry {
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
     expression[ApproximatePercentile]("approx_percentile"),
-    expression[StddevSamp]("std"),
-    expression[StddevSamp]("stddev"),
+    expressionWithAlias[StddevSamp]("std"),
+    expressionWithAlias[StddevSamp]("stddev"),
     expression[StddevPop]("stddev_pop"),
-    expression[StddevSamp]("stddev_samp"),
+    expressionWithAlias[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
-    expression[VarianceSamp]("variance"),
+    expressionWithAlias[VarianceSamp]("variance"),
     expression[VariancePop]("var_pop"),
-    expression[VarianceSamp]("var_samp"),
+    expressionWithAlias[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e76193fd9422..cc82eede2aca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -431,11 +431,11 @@ object TypeCoercion {
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
       case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
-      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case StddevSamp(funcName, e @ StringType()) => StddevSamp(funcName, Cast(e, DoubleType))
       case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType))
       case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
       case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
-      case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+      case VarianceSamp(funcName, e @ StringType()) => VarianceSamp(funcName, Cast(e, DoubleType))
       case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
       case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 8ce8dfa19c01..51057e45996f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -165,7 +165,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
   """,
   since = "1.6.0")
 // scalastyle:on line.size.limit
-case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
+case class StddevSamp(funcName: String, child: Expression) extends CentralMomentAgg(child) {
 
   override protected def momentOrder = 2
 
@@ -174,7 +174,7 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
   }
 
-  override def prettyName: String = "stddev_samp"
+  override def nodeName: String = funcName
 }
 
 // Compute the population variance of a column
@@ -206,7 +206,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
     1.0
   """,
   since = "1.6.0")
-case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
+case class VarianceSamp(funcName: String, child: Expression) extends CentralMomentAgg(child) {
 
   override protected def momentOrder = 2
 
@@ -215,7 +215,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
   }
 
-  override def prettyName: String = "var_samp"
+  override def nodeName: String = funcName
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 9f351395846e..87c7d4e11cad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -44,10 +44,10 @@ import org.apache.spark.sql.types._
     5
   """,
   since = "2.0.0")
-case class First(child: Expression, ignoreNullsExpr: Expression)
+case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+  def this(child: Expression) = this("first", child, Literal.create(false, BooleanType))
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
@@ -113,5 +113,5 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = first
 
-  override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
+  override def toString: String = s"$funcName($child)${if (ignoreNulls) " ignore nulls"}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 405719faaeb5..09c54f5bee78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -44,10 +44,10 @@ import org.apache.spark.sql.types._
     5
   """,
   since = "2.0.0")
-case class Last(child: Expression, ignoreNullsExpr: Expression)
+case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+  def this(child: Expression) = this("last", child, Literal.create(false, BooleanType))
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
@@ -111,5 +111,5 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = last
 
-  override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
+  override def toString: String = s"$funcName($child)${if (ignoreNulls) " ignore nulls"}"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 72e9e337c425..723314629117 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -431,7 +431,7 @@ object functions {
    * @since 2.0.0
    */
   def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new First(e.expr, Literal(ignoreNulls))
+    new First("first", e.expr, Literal(ignoreNulls))
   }
 
   /**
@@ -556,7 +556,7 @@ object functions {
    * @since 2.0.0
    */
   def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new Last(e.expr, Literal(ignoreNulls))
+    new Last("last", e.expr, Literal(ignoreNulls))
   }
 
   /**
@@ -675,7 +675,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
+  def stddev(e: Column): Column = withAggregateFunction { StddevSamp("stddev", e.expr) }
 
   /**
    * Aggregate function: alias for `stddev_samp`.
@@ -692,7 +692,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
+  def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp("stddev_samp", e.expr) }
 
   /**
    * Aggregate function: returns the sample standard deviation of
@@ -759,7 +759,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
+  def variance(e: Column): Column = withAggregateFunction { VarianceSamp("variance", e.expr) }
 
   /**
    * Aggregate function: alias for `var_samp`.
@@ -775,7 +775,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
+  def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp("var_samp", e.expr) }
 
   /**
    * Aggregate function: returns the unbiased variance of the values in a group.

From 7d5f4be7d7144bfe86c47f40a149a33de044d036 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Mon, 9 Dec 2019 16:38:29 +0530
Subject: [PATCH 02/32] Fixed errors

---
 .../optimizer/RewriteDistinctAggregates.scala      |  3 ++-
 .../sql/catalyst/parser/AstBuilder.scala           |  4 ++--
 .../expressions/aggregate/LastTestSuite.scala      |  5 ++--
 .../parser/ExpressionParserSuite.scala             | 24 +++++++++++--------
 .../sql/execution/stat/StatFunctions.scala         |  3 ++-
 5 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index b9468007cac6..a63eafa7b7db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -198,7 +198,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Select the result of the first aggregate in the last aggregate.
       val result = AggregateExpression(
-        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+        aggregate.First("first",
+          evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
         mode = Complete,
         isDistinct = false)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 858870a16141..89d6fceafd30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1510,7 +1510,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    First("first", expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
@@ -1518,7 +1518,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    Last("last", expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
index ba36bc074e15..82ea9e70bd83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.types.IntegerType
 
 class LastTestSuite extends SparkFunSuite {
   val input = AttributeReference("input", IntegerType, nullable = true)()
-  val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
-  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input))
+  val evaluator = DeclarativeAggregateEvaluator(Last("last", input, Literal(false)), Seq(input))
+  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(
+    Last("last", input, Literal(true)), Seq(input))
 
   test("empty buffer") {
     assert(evaluator.initialize() === InternalRow(null, false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 0b694ea95415..361cb8d7c4c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -787,19 +787,23 @@ class ExpressionParserSuite extends AnalysisTest {
   }
 
   test("SPARK-19526 Support ignore nulls keywords for first and last") {
-    assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression())
-    assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression())
-    assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression())
-    assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression())
+    assertEqual("first(a ignore nulls)",
First("first", 'a, Literal(true)).toAggregateExpression()) + assertEqual("first(a)", First("first", 'a, Literal(false)).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last("last", 'a, Literal(true)).toAggregateExpression()) + assertEqual("last(a)", Last("last", 'a, Literal(false)).toAggregateExpression()) } test("Support respect nulls keywords for first_value and last_value") { - assertEqual("first_value(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) - assertEqual("first_value(a respect nulls)", First('a, Literal(false)).toAggregateExpression()) - assertEqual("first_value(a)", First('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) - assertEqual("last_value(a respect nulls)", Last('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a)", Last('a, Literal(false)).toAggregateExpression()) + assertEqual("first_value(a ignore nulls)", + First("first", 'a, Literal(true)).toAggregateExpression()) + assertEqual("first_value(a respect nulls)", + First("first", 'a, Literal(false)).toAggregateExpression()) + assertEqual("first_value(a)", First("first", 'a, Literal(false)).toAggregateExpression()) + assertEqual("last_value(a ignore nulls)", + Last("last", 'a, Literal(true)).toAggregateExpression()) + assertEqual("last_value(a respect nulls)", + Last("last", 'a, Literal(false)).toAggregateExpression()) + assertEqual("last_value(a)", Last("last", 'a, Literal(false)).toAggregateExpression()) } test("timestamp literals") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index a6c9c2972df6..833925ae3253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -254,7 +254,8 @@ object StatFunctions extends Logging { stats.toLowerCase(Locale.ROOT) match { case "count" => (child: Expression) => Count(child).toAggregateExpression() case "mean" => (child: Expression) => Average(child).toAggregateExpression() - case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() + case "stddev" => (child: Expression) => + StddevSamp("stddev_samp", child).toAggregateExpression() case "min" => (child: Expression) => Min(child).toAggregateExpression() case "max" => (child: Expression) => Max(child).toAggregateExpression() case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") From 7ba7802f54766dc2200aea9364d04eecde1e02f0 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Tue, 10 Dec 2019 10:42:41 +0530 Subject: [PATCH 03/32] +constructor First --- .../sql/catalyst/expressions/aggregate/First.scala | 3 +++ .../optimizer/RewriteDistinctAggregates.scala | 3 +-- .../sql/catalyst/parser/ExpressionParserSuite.scala | 11 ++++++----- .../main/scala/org/apache/spark/sql/functions.scala | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 87c7d4e11cad..375ce57d63d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -49,6 +49,9 @@ case class 
 
   def this(child: Expression) = this("first", child, Literal.create(false, BooleanType))
 
+  def this(child: Expression, ignoreNullsExpr: Expression) =
+    this("first", child, ignoreNullsExpr)
+
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index a63eafa7b7db..bf88f8dec7f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -198,8 +198,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Select the result of the first aggregate in the last aggregate.
       val result = AggregateExpression(
-        aggregate.First("first",
-          evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+        new aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
         mode = Complete,
         isDistinct = false)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 361cb8d7c4c6..ab98cb823047 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -787,18 +787,19 @@ class ExpressionParserSuite extends AnalysisTest {
   }
 
   test("SPARK-19526 Support ignore nulls keywords for first and last") {
-    assertEqual("first(a ignore nulls)", First("first", 'a, Literal(true)).toAggregateExpression())
-    assertEqual("first(a)", First("first", 'a, Literal(false)).toAggregateExpression())
+    assertEqual("first(a ignore nulls)", new First('a, Literal(true)).toAggregateExpression())
+    assertEqual("first(a)", new First('a, Literal(false)).toAggregateExpression())
     assertEqual("last(a ignore nulls)", Last("last", 'a, Literal(true)).toAggregateExpression())
     assertEqual("last(a)", Last("last", 'a, Literal(false)).toAggregateExpression())
   }
 
   test("Support respect nulls keywords for first_value and last_value") {
     assertEqual("first_value(a ignore nulls)",
-      First("first", 'a, Literal(true)).toAggregateExpression())
+      new First('a, Literal(true)).toAggregateExpression())
     assertEqual("first_value(a respect nulls)",
-      First("first", 'a, Literal(false)).toAggregateExpression())
-    assertEqual("first_value(a)", First("first", 'a, Literal(false)).toAggregateExpression())
+      new First('a, Literal(false)).toAggregateExpression())
+    assertEqual("first_value(a)",
+      new First('a, Literal(false)).toAggregateExpression())
     assertEqual("last_value(a ignore nulls)",
       Last("last", 'a, Literal(true)).toAggregateExpression())
     assertEqual("last_value(a respect nulls)",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 723314629117..6bba1f29898b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -431,7 +431,7 @@ object functions {
    * @since 2.0.0
    */
   def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new First("first", e.expr, Literal(ignoreNulls))
+    new First(e.expr, Literal(ignoreNulls))
   }
 
   /**

From 759262d9752eaff035fb45415bb69b70c9f79b25 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Tue, 10 Dec 2019 10:52:07 +0530
Subject: [PATCH 04/32] +constructor Last

---
 .../sql/catalyst/analysis/FunctionRegistry.scala      |  4 ++--
 .../sql/catalyst/expressions/aggregate/First.scala    |  5 ++---
 .../sql/catalyst/expressions/aggregate/Last.scala     |  4 +++-
 .../apache/spark/sql/catalyst/parser/AstBuilder.scala |  4 ++--
 .../expressions/aggregate/LastTestSuite.scala         |  4 ++--
 .../sql/catalyst/parser/ExpressionParserSuite.scala   | 11 ++++++-----
 .../main/scala/org/apache/spark/sql/functions.scala   |  2 +-
 7 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b1752275ae86..11686609f0fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -304,12 +304,12 @@ object FunctionRegistry {
     expression[ApproximatePercentile]("approx_percentile"),
     expressionWithAlias[StddevSamp]("std"),
     expressionWithAlias[StddevSamp]("stddev"),
-    expression[StddevPop]("stddev_pop"),
     expressionWithAlias[StddevSamp]("stddev_samp"),
+    expression[StddevPop]("stddev_pop"),
     expression[Sum]("sum"),
     expressionWithAlias[VarianceSamp]("variance"),
-    expression[VariancePop]("var_pop"),
     expressionWithAlias[VarianceSamp]("var_samp"),
+    expression[VariancePop]("var_pop"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 375ce57d63d6..81f0179a34c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -47,10 +47,9 @@ import org.apache.spark.sql.types._
 case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this("first", child, Literal.create(false, BooleanType))
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
-  def this(child: Expression, ignoreNullsExpr: Expression) =
-    this("first", child, ignoreNullsExpr)
+  def this(child: Expression, ignoreNullsExpr: Expression) = this("first", child, ignoreNullsExpr)
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 09c54f5bee78..5b7262997045 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -47,7 +47,9 @@ import org.apache.spark.sql.types._
 case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this("last", child, Literal.create(false, BooleanType))
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  def this(child: Expression, ignoreNullsExpr: Expression) = this("last", child, ignoreNullsExpr)
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 89d6fceafd30..f84599d4e76d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1510,7 +1510,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    First("first", expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    new First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
@@ -1518,7 +1518,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    Last("last", expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    new Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
index 82ea9e70bd83..ccf8b1be4e58 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.types.IntegerType
 
 class LastTestSuite extends SparkFunSuite {
   val input = AttributeReference("input", IntegerType, nullable = true)()
-  val evaluator = DeclarativeAggregateEvaluator(Last("last", input, Literal(false)), Seq(input))
+  val evaluator = DeclarativeAggregateEvaluator(new Last(input, Literal(false)), Seq(input))
   val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(
-    Last("last", input, Literal(true)), Seq(input))
+    new Last(input, Literal(true)), Seq(input))
 
   test("empty buffer") {
     assert(evaluator.initialize() === InternalRow(null, false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index ab98cb823047..2c813c84ecd5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -789,8 +789,8 @@ class ExpressionParserSuite extends AnalysisTest {
   test("SPARK-19526 Support ignore nulls keywords for first and last") {
     assertEqual("first(a ignore nulls)", new First('a, Literal(true)).toAggregateExpression())
     assertEqual("first(a)", new First('a, Literal(false)).toAggregateExpression())
-    assertEqual("last(a ignore nulls)", Last("last", 'a, Literal(true)).toAggregateExpression())
-    assertEqual("last(a)", Last("last", 'a, Literal(false)).toAggregateExpression())
+    assertEqual("last(a ignore nulls)", new Last('a, Literal(true)).toAggregateExpression())
+    assertEqual("last(a)", new Last('a, Literal(false)).toAggregateExpression())
   }
 
   test("Support respect nulls keywords for first_value and last_value") {
@@ -801,10 +801,11 @@ class ExpressionParserSuite extends AnalysisTest {
     assertEqual("first_value(a)",
       new First('a, Literal(false)).toAggregateExpression())
     assertEqual("last_value(a ignore nulls)",
-      Last("last", 'a, Literal(true)).toAggregateExpression())
+      new Last('a, Literal(true)).toAggregateExpression())
     assertEqual("last_value(a respect nulls)",
-      Last("last", 'a, Literal(false)).toAggregateExpression())
-    assertEqual("last_value(a)", Last("last", 'a, Literal(false)).toAggregateExpression())
+      new Last('a, Literal(false)).toAggregateExpression())
+    assertEqual("last_value(a)",
+      new Last('a, Literal(false)).toAggregateExpression())
   }
 
   test("timestamp literals") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6bba1f29898b..9b4752540c91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -556,7 +556,7 @@ object functions {
    * @since 2.0.0
    */
   def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new Last("last", e.expr, Literal(ignoreNulls))
+    new Last(e.expr, Literal(ignoreNulls))
   }
 
   /**

From 5ec101fa794c99a13f250f228f92d6620014d013 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Tue, 10 Dec 2019 11:13:49 +0530
Subject: [PATCH 05/32] Fixed build error

---
 .../spark/sql/catalyst/expressions/aggregate/First.scala | 4 ++--
 .../spark/sql/catalyst/expressions/aggregate/Last.scala  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 81f0179a34c2..21c72fab9049 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -47,10 +47,10 @@ import org.apache.spark.sql.types._
 case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-  def this(child: Expression, ignoreNullsExpr: Expression) = this("first", child, ignoreNullsExpr)
+  def this(child: Expression, ignoreNullsExpr: Expression) = this("first", child, ignoreNullsExpr)
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 5b7262997045..90a7796bf0cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -47,10 +47,10 @@ import org.apache.spark.sql.types._
 case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-  def this(child: Expression, ignoreNullsExpr: Expression) = this("last", child, ignoreNullsExpr)
+  def this(child: Expression, ignoreNullsExpr: Expression) = this("last", child, ignoreNullsExpr)
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true

From f210bb94eb25a6cfe59818f16baa57bf77153143 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Tue, 10 Dec 2019 21:15:34 +0530
Subject: [PATCH 06/32] Fixed TC

---
 .../spark/sql/catalyst/analysis/FunctionRegistry.scala   | 1 -
 .../spark/sql/catalyst/expressions/aggregate/First.scala | 7 ++++++-
 .../spark/sql/catalyst/expressions/aggregate/Last.scala  | 3 +++
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11686609f0fc..e8e1de05b88e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -635,7 +635,6 @@ object FunctionRegistry {
       (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
     val constructors = tag.runtimeClass.getConstructors
       .filter(_.getParameterTypes.head == classOf[String])
-    assert(constructors.length == 1)
     val builder = (expressions: Seq[Expression]) => {
       val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression])
       val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 21c72fab9049..1a763c456f2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -51,6 +51,9 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
+  def this(funcName: String, child: Expression) =
+    this(funcName, child, Literal.create(false, BooleanType))
+
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
@@ -115,5 +118,7 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   override lazy val evaluateExpression: AttributeReference = first
 
-  override def toString: String = s"$funcName($child)${if (ignoreNulls) " ignore nulls"}"
+  override def nodeName: String = funcName
+
+  override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 90a7796bf0cf..748a0a5b62ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -51,6 +51,9 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
+  def this(funcName: String, child: Expression) =
+    this(funcName, child, Literal.create(false, BooleanType))
+  
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true

From cc234b95c203e516b768163bab2776a171b51bc5 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Tue, 10 Dec 2019 21:25:06 +0530
Subject: [PATCH 07/32] ScalaStyle Fix

---
 .../apache/spark/sql/catalyst/expressions/aggregate/Last.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 748a0a5b62ec..0e84d2915810 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -53,7 +53,7 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   def this(funcName: String, child: Expression) =
     this(funcName, child, Literal.create(false, BooleanType))
-  
+
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true

From 7617ec3db3db632c2d6771b96787ebea7482b884 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Wed, 11 Dec 2019 08:58:51 +0530
Subject: [PATCH 08/32] Fixed TC

---
 .../resources/sql-tests/results/group-by.sql.out   |  2 +-
 .../results/postgreSQL/window_part4.sql.out        | 16 ++++++++--------
 .../udf/postgreSQL/udf-aggregates_part1.sql.out    |  4 ++--
 .../sql-tests/results/udf/udf-group-by.sql.out     |  2 +-
 .../org/apache/spark/sql/SQLQueryTestSuite.scala   |  1 +
 5 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index ed5ced8c8c0f..62a166649708 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -128,7 +128,7 @@ NULL 1
 SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
 FROM testData
 -- !query 13 schema
-struct
+struct
 -- !query 13 output
 -0.2723801058145729	-1.5069204152249134	1	3	2.142857142857143	0.8095238095238094	0.8997354108424372	15	7
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
index 54ceacd3b3b3..4721ceb03a96 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
@@ -241,7 +241,7 @@ NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 19 schema
-struct
+struct
 -- !query 19 output
 16900.0
 18491.666666666668
@@ -254,7 +254,7 @@ NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 20 schema
-struct
+struct
 -- !query 20 output
 16900.0
 18491.666666666668
@@ -267,7 +267,7 @@ NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 21 schema
-struct
+struct
 -- !query 21 output
 16900.0
 18491.666666666668
@@ -280,7 +280,7 @@ NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 22 schema
-struct
+struct
 -- !query 22 output
 16900.0
 18491.666666666668
@@ -405,7 +405,7 @@ NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 31 schema
-struct
+struct
 -- !query 31 output
 130.0
 135.9840676942217
@@ -419,7 +419,7 @@ NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 32 schema
-struct
+struct
 -- !query 32 output
 130.0
 135.9840676942217
@@ -433,7 +433,7 @@ NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 33 schema
-struct
+struct
 -- !query 33 output
 130.0
 135.9840676942217
@@ -447,7 +447,7 @@ NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 34 schema
-struct
+struct
 -- !query 34 output
 130.0
 135.9840676942217
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
index a2f64717d73a..83cb34bbc17c 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
@@ -85,7 +85,7 @@ struct
 -- !query 10
 SELECT udf(stddev_samp(b)) FROM aggtest
 -- !query 10 schema
-struct
+struct
 -- !query 10 output
 151.38936080399804
@@ -101,7 +101,7 @@ struct
 -- !query 12
 SELECT udf(var_samp(b)) FROM aggtest
 -- !query 12 schema
-struct
+struct
 -- !query 12 output
 22918.738564643096
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
index bdacd184158a..37d7a0fc55a5 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
@@ -128,7 +128,7 @@ NULL 1
 SELECT SKEWNESS(udf(a)), udf(KURTOSIS(a)), udf(MIN(a)), MAX(udf(a)), udf(AVG(udf(a))), udf(VARIANCE(a)), STDDEV(udf(a)), udf(SUM(a)), udf(COUNT(a))
 FROM testData
 -- !query 13 schema
-struct
+struct
 -- !query 13 output
 -0.2723801058145729	-1.5069204152249134	1	3	2.142857142857143	0.8095238095238094	0.8997354108424372	15	7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index e6dcf0b86308..b0e866eef942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -126,6 +126,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   // 1. Maven can't get correct resource directory when resources in other jars.
   // 2. We test subclasses in the hive-thriftserver module.
   val sparkHome = {
+    sys.props.put("spark.test.home", "/home/root1/spark")
     assert(sys.props.contains("spark.test.home") ||
       sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.")
     sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))

From 92e381b00d884107416a87854f95d33601b8a8fd Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Wed, 11 Dec 2019 09:19:46 +0530
Subject: [PATCH 09/32] Reduce constructors

---
 .../spark/sql/catalyst/dsl/package.scala           |  4 ++--
 .../expressions/aggregate/First.scala              | 13 ++++++++----
 .../catalyst/expressions/aggregate/Last.scala      | 13 ++++++++----
 .../sql/catalyst/optimizer/Optimizer.scala         |  2 +-
 .../optimizer/RewriteDistinctAggregates.scala      |  2 +-
 .../sql/catalyst/parser/AstBuilder.scala           |  4 ++--
 .../expressions/aggregate/LastTestSuite.scala      |  4 ++--
 .../optimizer/ReplaceOperatorSuite.scala           |  2 +-
 .../parser/ExpressionParserSuite.scala             | 20 +++++++++----------
 .../org/apache/spark/sql/functions.scala           |  4 ++--
 10 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b4a8bafe22df..34822f617bac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -173,8 +173,8 @@ package object dsl {
     def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
       HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
     def avg(e: Expression): Expression = Average(e).toAggregateExpression()
-    def first(e: Expression): Expression = new First(e).toAggregateExpression()
-    def last(e: Expression): Expression = new Last(e).toAggregateExpression()
+    def first(e: Expression): Expression = First(e).toAggregateExpression()
+    def last(e: Expression): Expression = Last(e).toAggregateExpression()
     def min(e: Expression): Expression = Min(e).toAggregateExpression()
     def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true)
    def max(e: Expression): Expression = Max(e).toAggregateExpression()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 1a763c456f2d..e6dcaecb212e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -47,10 +47,6 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression, ignoreNullsExpr: Expression) = this("first", child, ignoreNullsExpr)
-
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
   def this(funcName: String, child: Expression) =
     this(funcName, child, Literal.create(false, BooleanType))
 
@@ -122,3 +118,12 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
 }
+
+object First {
+
+  def apply(child: Expression, ignoreNullsExpr: Expression): First =
+    First("first", child, ignoreNullsExpr)
+
+  def apply(child: Expression): First =
+    First("first", child)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 0e84d2915810..7bf6e4e6bfbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -47,10 +47,6 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression)
   extends DeclarativeAggregate with ExpectsInputTypes {
 
-  def this(child: Expression, ignoreNullsExpr: Expression) = this("last", child, ignoreNullsExpr)
-
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
   def this(funcName: String, child: Expression) =
     this(funcName, child, Literal.create(false, BooleanType))
 
@@ -118,3 +114,12 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expressio
 
   override def toString: String = s"$funcName($child)${if (ignoreNulls) " ignore nulls"}"
 }
+
+object Last {
+
+  def apply(child: Expression, ignoreNullsExpr: Expression): Last =
+    Last("last", child, ignoreNullsExpr)
+
+  def apply(child: Expression): Last =
+    Last("last", child)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 05fd5e35e22a..72f0c1afa3c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1539,7 +1539,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
       if (keyExprIds.contains(attr.exprId)) {
         attr
       } else {
-        Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
+        Alias(First(attr).toAggregateExpression(), attr.name)(attr.exprId)
       }
     }
     // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index bf88f8dec7f8..b9468007cac6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -198,7 +198,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Select the result of the first aggregate in the last aggregate.
       val result = AggregateExpression(
-        new aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
         mode = Complete,
         isDistinct = false)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f84599d4e76d..858870a16141 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1510,7 +1510,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    new First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
@@ -1518,7 +1518,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    new Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
index ccf8b1be4e58..9e706b7137fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.types.IntegerType
 
 class LastTestSuite extends SparkFunSuite {
   val input = AttributeReference("input", IntegerType, nullable = true)()
-  val evaluator = DeclarativeAggregateEvaluator(new Last(input, Literal(false)), Seq(input))
+  val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
   val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(
-    new Last(input, Literal(true)), Seq(input))
+    Last(input, Literal(true)), Seq(input))
 
   test("empty buffer") {
     assert(evaluator.initialize() === InternalRow(null, false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 9bf864f5201f..6bf3e3b94ecc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -203,7 +203,7 @@ class ReplaceOperatorSuite extends PlanTest {
       Seq(attrA),
       Seq(
         attrA,
-        Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId)
+        Alias(First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId)
       ),
       input)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 2c813c84ecd5..9a50a10a471a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -787,25 +787,25 @@ class ExpressionParserSuite extends AnalysisTest {
   }
 
   test("SPARK-19526 Support ignore nulls keywords for first and last") {
-    assertEqual("first(a ignore nulls)", new First('a, Literal(true)).toAggregateExpression())
-    assertEqual("first(a)", new First('a, Literal(false)).toAggregateExpression())
-    assertEqual("last(a ignore nulls)", new Last('a, Literal(true)).toAggregateExpression())
-    assertEqual("last(a)", new Last('a, Literal(false)).toAggregateExpression())
+    assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression())
+    assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression())
+    assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression())
+    assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression())
   }
 
   test("Support respect nulls keywords for first_value and last_value") {
     assertEqual("first_value(a ignore nulls)",
-      new First('a, Literal(true)).toAggregateExpression())
+      First('a, Literal(true)).toAggregateExpression())
     assertEqual("first_value(a respect nulls)",
-      new First('a, Literal(false)).toAggregateExpression())
+      First('a, Literal(false)).toAggregateExpression())
     assertEqual("first_value(a)",
-      new First('a, Literal(false)).toAggregateExpression())
+      First('a, Literal(false)).toAggregateExpression())
     assertEqual("last_value(a ignore nulls)",
-      new Last('a, Literal(true)).toAggregateExpression())
+      Last('a, Literal(true)).toAggregateExpression())
     assertEqual("last_value(a respect nulls)",
-      new Last('a, Literal(false)).toAggregateExpression())
+      Last('a, Literal(false)).toAggregateExpression())
     assertEqual("last_value(a)",
-      new Last('a, Literal(false)).toAggregateExpression())
+      Last('a, Literal(false)).toAggregateExpression())
   }
 
   test("timestamp literals") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9b4752540c91..b52ba2474a47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -431,7 +431,7 @@ object functions {
    * @since 2.0.0
    */
   def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new First(e.expr, Literal(ignoreNulls))
+    First(e.expr, Literal(ignoreNulls))
   }
 
   /**
@@ -556,7 +556,7 @@ object functions {
    * @since 2.0.0
    */
   def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new Last(e.expr, Literal(ignoreNulls))
+    Last(e.expr, Literal(ignoreNulls))
   }
 
   /**

From a2d75de6d4ee273fbc451effdf4b6a9303232085 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Wed, 11 Dec 2019 11:55:06 +0530
Subject: [PATCH 10/32] nit

---
 .../src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index b0e866eef942..e6dcf0b86308 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -126,7 +126,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   // 1. Maven can't get correct resource directory when resources in other jars.
   // 2. We test subclasses in the hive-thriftserver module.
val sparkHome = { - sys.props.put("spark.test.home", "/home/root1/spark") assert(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.") sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) From 3ef62af9c8449d39bcb2c1f9655de4f57920c584 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 11 Dec 2019 14:27:15 +0530 Subject: [PATCH 11/32] Removed unnecessary changes --- .../expressions/aggregate/LastTestSuite.scala | 3 +-- .../parser/ExpressionParserSuite.scala | 18 ++++++------------ .../sql/execution/stat/StatFunctions.scala | 2 +- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala index 9e706b7137fe..ba36bc074e15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.types.IntegerType class LastTestSuite extends SparkFunSuite { val input = AttributeReference("input", IntegerType, nullable = true)() val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) - val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator( - Last(input, Literal(true)), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) test("empty buffer") { assert(evaluator.initialize() === InternalRow(null, false)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 9a50a10a471a..0b694ea95415 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -794,18 +794,12 @@ class ExpressionParserSuite extends AnalysisTest { } test("Support respect nulls keywords for first_value and last_value") { - assertEqual("first_value(a ignore nulls)", - First('a, Literal(true)).toAggregateExpression()) - assertEqual("first_value(a respect nulls)", - First('a, Literal(false)).toAggregateExpression()) - assertEqual("first_value(a)", - First('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a ignore nulls)", - Last('a, Literal(true)).toAggregateExpression()) - assertEqual("last_value(a respect nulls)", - Last('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a)", - Last('a, Literal(false)).toAggregateExpression()) + assertEqual("first_value(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) + assertEqual("first_value(a respect nulls)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("first_value(a)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("last_value(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) + assertEqual("last_value(a respect nulls)", Last('a, Literal(false)).toAggregateExpression()) + assertEqual("last_value(a)", Last('a, Literal(false)).toAggregateExpression()) } test("timestamp literals") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 833925ae3253..4ddf35cbb3ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -255,7 +255,7 @@ object StatFunctions extends Logging { case "count" => (child: Expression) => Count(child).toAggregateExpression() case "mean" => (child: Expression) => Average(child).toAggregateExpression() case "stddev" => (child: Expression) => - StddevSamp("stddev_samp", child).toAggregateExpression() + StddevSamp("stddev", child).toAggregateExpression() case "min" => (child: Expression) => Min(child).toAggregateExpression() case "max" => (child: Expression) => Max(child).toAggregateExpression() case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") From 87c6ea3edb3c5a249c5d2f392a0171a2b1592b44 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 11 Dec 2019 16:58:04 +0530 Subject: [PATCH 12/32] Fixed TC --- .../spark/sql/catalyst/expressions/aggregate/First.scala | 4 ++-- .../spark/sql/catalyst/expressions/aggregate/Last.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index e6dcaecb212e..ffaa695c214a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -122,8 +122,8 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio object First { def apply(child: Expression, ignoreNullsExpr: Expression): First = - First("first", child, ignoreNullsExpr) + new First("first", child, ignoreNullsExpr) def apply(child: Expression): First = - First("first", child) + new First("first", child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 7bf6e4e6bfbb..fb42e9d594eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -118,8 +118,8 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression object Last { def apply(child: Expression, ignoreNullsExpr: Expression): Last = - Last("last", child, ignoreNullsExpr) + new Last("last", child, ignoreNullsExpr) def apply(child: Expression): Last = - Last("last", child) + new Last("last", child) } From 94805329761ca8a12cc8183b2644f5b869461562 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 11 Dec 2019 19:02:38 +0530 Subject: [PATCH 13/32] Override flatArguments in VarianceSamp, StddevSamp --- .../sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 4 ++++ .../results/udf/postgreSQL/udf-aggregates_part1.sql.out | 4 ++-- .../test/resources/sql-tests/results/udf/udf-group-by.sql.out | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 51057e45996f..b7686f85a016 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -175,6 +175,8 @@ case class StddevSamp(funcName: String, child: Expression) extends CentralMoment } override def nodeName: String = funcName + + override def flatArguments: Iterator[Any] = Iterator(child) } // Compute the population variance of a column @@ -216,6 +218,8 @@ case class VarianceSamp(funcName: String, child: Expression) extends CentralMome } override def nodeName: String = funcName + + override def flatArguments: Iterator[Any] = Iterator(child) } @ExpressionDescription( diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out index 83cb34bbc17c..a2f64717d73a 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out @@ -85,7 +85,7 @@ struct -- !query 10 SELECT udf(stddev_samp(b)) FROM aggtest -- !query 10 schema -struct +struct -- !query 10 output 151.38936080399804 @@ -101,7 +101,7 @@ struct -- !query 12 SELECT udf(var_samp(b)) FROM aggtest -- !query 12 schema -struct +struct -- !query 12 output 22918.738564643096 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index 37d7a0fc55a5..8a396cc60e97 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -128,7 +128,7 @@ NULL 1 SELECT SKEWNESS(udf(a)), udf(KURTOSIS(a)), udf(MIN(a)), MAX(udf(a)), udf(AVG(udf(a))), udf(VARIANCE(a)), STDDEV(udf(a)), udf(SUM(a)), udf(COUNT(a)) FROM testData -- !query 13 schema -struct +struct -- !query 13 output -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 From 5c540fc5cc2c149ccc7fd48296a1aa7acd9078de Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 12 Dec 2019 07:37:27 +0530 Subject: [PATCH 14/32] override flatArguments in First, Last --- .../spark/sql/catalyst/expressions/aggregate/First.scala | 2 ++ .../spark/sql/catalyst/expressions/aggregate/Last.scala | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index ffaa695c214a..af8b2c0b1fda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -117,6 +117,8 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio override def nodeName: String = funcName override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override def flatArguments: Iterator[Any] = Iterator(child, ignoreNullsExpr) } object First { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index fb42e9d594eb..535a5f72f17f 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -112,7 +112,11 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression override lazy val evaluateExpression: AttributeReference = last - override def toString: String = s"$funcName($child)${if (ignoreNulls) " ignore nulls"}" + override def nodeName: String = funcName + + override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override def flatArguments: Iterator[Any] = Iterator(child, ignoreNullsExpr) } object Last { From d780dfc8660c5f78b683b4baf810908ec2a143ca Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 12 Dec 2019 07:42:04 +0530 Subject: [PATCH 15/32] override flatArguments in BoolAnd, BoolOr --- .../expressions/aggregate/UnevaluableAggs.scala | 2 ++ .../sql-tests/results/udf/udf-group-by.sql.out | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index acb0af0248a7..6a94933f3b02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -38,6 +38,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) case _ => TypeCheckResult.TypeCheckSuccess } } + + override def flatArguments: Iterator[Any] = Iterator(arg) } @ExpressionDescription( diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index 8a396cc60e97..a835740a6a86 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -293,7 +293,7 @@ struct<> -- !query 31 SELECT udf(every(v)), udf(some(v)), any(v) FROM test_agg WHERE 1 = 0 -- !query 31 schema -struct +struct -- !query 31 output NULL NULL NULL @@ -301,7 +301,7 @@ NULL NULL NULL -- !query 32 SELECT udf(every(udf(v))), some(v), any(v) FROM test_agg WHERE k = 4 -- !query 32 schema -struct +struct -- !query 32 output NULL NULL NULL @@ -309,7 +309,7 @@ NULL NULL NULL -- !query 33 SELECT every(v), udf(some(v)), any(v) FROM test_agg WHERE k = 5 -- !query 33 schema -struct +struct -- !query 33 output false true true @@ -317,7 +317,7 @@ false true true -- !query 34 SELECT udf(k), every(v), udf(some(v)), any(v) FROM test_agg GROUP BY udf(k) -- !query 34 schema -struct +struct -- !query 34 output 1 false true true 2 true true true @@ -339,7 +339,7 @@ struct -- !query 36 SELECT udf(k), udf(every(v)) FROM test_agg GROUP BY udf(k) HAVING every(v) IS NULL -- !query 36 schema -struct +struct -- !query 36 output 4 NULL From 85d95977bc71b987148412a08b74a06576c7bb63 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 12 Dec 2019 15:42:06 +0530 Subject: [PATCH 16/32] add assert() --- .../apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e8e1de05b88e..77e48522b8f2 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -635,6 +635,9 @@ object FunctionRegistry { (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { val constructors = tag.runtimeClass.getConstructors .filter(_.getParameterTypes.head == classOf[String]) + assert(constructors.length >= 1, + s"there is no constructor for ${tag.runtimeClass} " + + "which takes String as first argument") val builder = (expressions: Seq[Expression]) => { val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { From a71e8a711f801bdc4d73d21b27e76e5652ac3ef3 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Fri, 13 Dec 2019 18:13:30 +0530 Subject: [PATCH 17/32] expressionWithAlias for Average, ApproximatePercentile & override nodeName, flatArguments --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 8 ++++---- .../expressions/aggregate/ApproximatePercentile.scala | 9 +++++++-- .../sql/catalyst/expressions/aggregate/Average.scala | 8 ++++++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 77e48522b8f2..d1a0717de3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -282,7 +282,8 @@ object FunctionRegistry { // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), - expression[Average]("avg"), + expressionWithAlias[Average]("avg"), + expressionWithAlias[Average]("mean"), expression[Corr]("corr"), expression[Count]("count"), expression[CountIf]("count_if"), @@ -295,13 +296,12 @@ object FunctionRegistry { expressionWithAlias[Last]("last_value"), expression[Max]("max"), expression[MaxBy]("max_by"), - expression[Average]("mean"), expression[Min]("min"), expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), - expression[ApproximatePercentile]("percentile_approx"), - expression[ApproximatePercentile]("approx_percentile"), + expressionWithAlias[ApproximatePercentile]("percentile_approx"), + expressionWithAlias[ApproximatePercentile]("approx_percentile"), expressionWithAlias[StddevSamp]("std"), expressionWithAlias[StddevSamp]("stddev"), expressionWithAlias[StddevSamp]("stddev_samp"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index ea0ed2e8fa11..7efb3b99445c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -67,6 +67,7 @@ import org.apache.spark.sql.types._ """, since = "2.1.0") case class ApproximatePercentile( + funcName: String, child: Expression, percentageExpression: Expression, accuracyExpression: Expression, @@ -75,7 +76,7 @@ case class ApproximatePercentile( extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { def this(child: Expression, 
percentageExpression: Expression, accuracyExpression: Expression) = { - this(child, percentageExpression, accuracyExpression, 0, 0) + this("percentile_approx", child, percentageExpression, accuracyExpression, 0, 0) } def this(child: Expression, percentageExpression: Expression) = { @@ -185,7 +186,7 @@ case class ApproximatePercentile( if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType } - override def prettyName: String = "percentile_approx" + override def nodeName: String = funcName override def serialize(obj: PercentileDigest): Array[Byte] = { ApproximatePercentile.serializer.serialize(obj) @@ -194,6 +195,10 @@ case class ApproximatePercentile( override def deserialize(bytes: Array[Byte]): PercentileDigest = { ApproximatePercentile.serializer.deserialize(bytes) } + + override def flatArguments: Iterator[Any] = + Iterator(child, percentageExpression, accuracyExpression, + mutableAggBufferOffset, inputAggBufferOffset) } object ApproximatePercentile { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index aaad3c7bcefa..df680d01cd25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -35,9 +35,11 @@ import org.apache.spark.sql.types._ -3 days -11 hours -59 minutes -59 seconds """, since = "1.0.0") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { +case class Average( + funcName: String, child: Expression) + extends DeclarativeAggregate with ImplicitCastInputTypes { - override def prettyName: String = "avg" + override def nodeName: String = funcName override def children: Seq[Expression] = child :: Nil @@ -93,4 +95,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit coalesce(child.cast(sumDataType), Literal.default(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) + + override def flatArguments: Iterator[Any] = Iterator(child) } From dd2d85d18bccc3b5fd79a3bf569058cb03bd02c6 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Fri, 13 Dec 2019 20:02:04 +0530 Subject: [PATCH 18/32] Fixes for latest update --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 14 +++++++------- .../catalyst/expressions/aggregate/Average.scala | 4 ++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++---- .../spark/sql/RelationalGroupedDataset.scala | 4 ++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cc82eede2aca..763365f04965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -429,7 +429,7 @@ object TypeCoercion { case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Average(funcName, e @ StringType()) => Average(funcName, Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(funcName, e @ StringType()) => StddevSamp(funcName, Cast(e, DoubleType)) case 
UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) @@ -613,15 +613,15 @@ object TypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. - case Average(e @ IntegralType()) if e.dataType != LongType => - Average(Cast(e, LongType)) - case Average(e @ FractionalType()) if e.dataType != DoubleType => - Average(Cast(e, DoubleType)) + case s @ Average(_, DecimalType()) => s // Decimal is already the biggest. + case Average(funcName, e @ IntegralType()) if e.dataType != LongType => + Average(funcName, Cast(e, LongType)) + case Average(funcName, e @ FractionalType()) if e.dataType != DoubleType => + Average(funcName, Cast(e, DoubleType)) // Hive lets you do aggregation of timestamps... for some reason case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) - case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + case Average(funcName, e @ TimestampType()) => Average(funcName, Cast(e, DoubleType)) // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index df680d01cd25..94e99b46d044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -98,3 +98,7 @@ case class Average( override def flatArguments: Iterator[Any] = Iterator(child) } + +object Average { + def apply(child: Expression): Average = Average("avg", child) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 72f0c1afa3c3..a4f7201f99a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1464,9 +1464,9 @@ object DecimalAggregates extends Rule[LogicalPlan] { MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case Average(f, e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => val newAggExpr = - we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) + we.copy(windowFunction = ae.copy(aggregateFunction = Average(f, UnscaledValue(e)))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) @@ -1477,8 +1477,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr =
ae.copy(aggregateFunction = Average(f, UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 52bd0ecb1fff..54655b743c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -278,7 +278,7 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericOrIntervalColumns(colNames : _*)(Average) + aggregateNumericOrIntervalColumns(colNames : _*)(Average.apply) } /** @@ -302,7 +302,7 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericOrIntervalColumns(colNames : _*)(Average) + aggregateNumericOrIntervalColumns(colNames : _*)(Average.apply) } /** From c1b3afb3efde1504257a9af20910a11b1f0cf503 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Sat, 14 Dec 2019 21:48:55 +0530 Subject: [PATCH 19/32] Fix ApproximatePercentile TC --- .../aggregate/ApproximatePercentile.scala | 31 +++++++++++++++--- .../ApproximatePercentileSuite.scala | 32 +++++++++---------- .../sql/execution/command/CommandUtils.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 7efb3b99445c..799315cfbac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -75,12 +75,17 @@ case class ApproximatePercentile( override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { - def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { - this("percentile_approx", child, percentageExpression, accuracyExpression, 0, 0) + def this( + funcName: String, + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression) = { + this(funcName, child, percentageExpression, accuracyExpression, 0, 0) } - def this(child: Expression, percentageExpression: Expression) = { - this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + def this(funcName: String, child: Expression, percentageExpression: Expression) = { + this(funcName, child, percentageExpression, + Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) } // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. 
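A note on why the coercion and optimizer rules above pattern-match the name out and put the same name back: nodeName now derives the display name from funcName, so a rewrite that dropped it would silently rename the user's output column (for example, mean(ts) would start printing as avg). A minimal sketch with stub types, not the real Catalyst API:

object FuncNameThreadingSketch extends App {
  trait Expression
  case class Attr(name: String) extends Expression
  case class Cast(child: Expression) extends Expression
  case class Average(funcName: String, child: Expression) extends Expression {
    def nodeName: String = funcName
  }

  // The rewrite keeps f intact, so mean(ts) still displays as "mean", not "avg".
  val rewritten = Average("mean", Attr("ts")) match {
    case Average(f, child) => Average(f, Cast(child))
  }
  println(rewritten.nodeName) // prints "mean"
}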
@@ -326,4 +331,22 @@ object ApproximatePercentile { } val serializer: PercentileDigestSerializer = new PercentileDigestSerializer + + def apply( + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression, + mutableAggBufferOffset: Int, inputAggBufferOffset: Int): ApproximatePercentile = + new ApproximatePercentile("percentile_approx", child, percentageExpression, + accuracyExpression, mutableAggBufferOffset, inputAggBufferOffset) + + def apply( + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression): ApproximatePercentile = { + new ApproximatePercentile("percentile_approx", child, percentageExpression, accuracyExpression) + } + + def apply(child: Expression, percentageExpression: Expression): ApproximatePercentile = + new ApproximatePercentile("percentile_approx", child, percentageExpression) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 84b3cc79cef5..bad3f24178af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -53,7 +53,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) - val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5)) + val agg = ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5)) assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } @@ -103,7 +103,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) val accuracyExpression = Literal(10000) - val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) + val agg = ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) assert(agg.nullable) val group1 = (0 until data.length / 2) @@ -140,7 +140,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { val percentage = 0.5D // Phase one, partial mode aggregation - val agg = new ApproximatePercentile(childExpression, Literal(percentage)) + val agg = ApproximatePercentile(childExpression, Literal(percentage)) .withNewInputAggBufferOffset(inputAggregationBufferOffset) .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) @@ -170,12 +170,12 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql, single percentile assertEqual( s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", - new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) + ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) // sql, array of percentile assertEqual( s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - new ApproximatePercentile( + ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql: String) @@ -183,13 +183,13 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = false), single percentile assertEqual( s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", - new 
ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = false)) // sql(isDistinct = false), array of percentile assertEqual( s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - new ApproximatePercentile( + ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql(isDistinct = false)) @@ -197,13 +197,13 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = true), single percentile assertEqual( s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)", - new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = true)) // sql(isDistinct = true), array of percentile assertEqual( s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - new ApproximatePercentile( + ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql(isDistinct = true)) @@ -211,7 +211,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") { val attribute = AttributeReference("a", DoubleType)() - val wrongAccuracy = new ApproximatePercentile( + val wrongAccuracy = ApproximatePercentile( attribute, percentageExpression = Literal(0.5D), accuracyExpression = AttributeReference("b", IntegerType)()) @@ -221,7 +221,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { TypeCheckFailure("The accuracy or percentage provided must be a constant literal") ) - val wrongPercentage = new ApproximatePercentile( + val wrongPercentage = ApproximatePercentile( attribute, percentageExpression = attribute, accuracyExpression = Literal(10000)) @@ -233,7 +233,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } test("class ApproximatePercentile, fails analysis if parameters are invalid") { - val wrongAccuracy = new ApproximatePercentile( + val wrongAccuracy = ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = Literal(0.5D), accuracyExpression = Literal(-1)) @@ -249,7 +249,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_))) ) correctPercentageExpresions.foreach { percentageExpression => - val correctPercentage = new ApproximatePercentile( + val correctPercentage = ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = percentageExpression, accuracyExpression = Literal(100)) @@ -265,7 +265,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { ) wrongPercentageExpressions.foreach { percentageExpression => - val wrongPercentage = new ApproximatePercentile( + val wrongPercentage = ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = percentageExpression, accuracyExpression = Literal(100)) @@ -289,7 +289,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { accuracyExpressions.foreach { accuracyExpression => percentageExpressions.foreach { percentageExpression => - val agg = new ApproximatePercentile( + val agg = ApproximatePercentile( UnresolvedAttribute("a"), percentageExpression, accuracyExpression) @@ -309,7 +309,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, null handling") { val childExpression = Cast(BoundReference(0, 
IntegerType, nullable = true), DoubleType) - val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) + val agg = ApproximatePercentile(childExpression, Literal(0.5D)) val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index b644e6dc471d..422dde5d0fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -214,7 +214,7 @@ object CommandUtils extends Logging { val namedExprs = attrsToGenHistogram.map { attr => val aggFunc = - new ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) + ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) val expr = aggFunc.toAggregateExpression() Alias(expr, expr.toString)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 4ddf35cbb3ed..e62e64fbe858 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -248,7 +248,7 @@ object StatFunctions extends Logging { percentileIndex += 1 (child: Expression) => GetArrayItem( - new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), + ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), Literal(index)) } else { stats.toLowerCase(Locale.ROOT) match { From e7a4e904baab4963d0c73aff5be7730a20bc08e3 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Sat, 14 Dec 2019 22:37:18 +0530 Subject: [PATCH 20/32] UT fix --- .../execution/benchmark/ObjectHashAggregateExecBenchmark.scala | 2 +- .../spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index c475c7b21ab9..5787af91199a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -208,7 +208,7 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark { private def percentile_approx( column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) + val approxPercentile = ApproximatePercentile(column.expr, Literal(percentage)) Column(approxPercentile.toAggregateExpression(isDistinct)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 930f80146749..b6ef34dcfd43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -175,7 +175,7 @@ class ObjectHashAggregateSuite private def 
percentile_approx( column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) + val approxPercentile = ApproximatePercentile(column.expr, Literal(percentage)) Column(approxPercentile.toAggregateExpression(isDistinct)) } From ca886f0ff7e470cd584411845a323dc7b8e52962 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 18 Dec 2019 15:08:39 +0530 Subject: [PATCH 21/32] nit --- .../aggregate/ApproximatePercentile.scala | 22 +++++++++---------- .../expressions/aggregate/Average.scala | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 799315cfbac7..1d3c99913349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -76,10 +76,10 @@ case class ApproximatePercentile( extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { def this( - funcName: String, - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression) = { + funcName: String, + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression) = { this(funcName, child, percentageExpression, accuracyExpression, 0, 0) } @@ -333,17 +333,17 @@ object ApproximatePercentile { val serializer: PercentileDigestSerializer = new PercentileDigestSerializer def apply( - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression, - mutableAggBufferOffset: Int, inputAggBufferOffset: Int): ApproximatePercentile = + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression, + mutableAggBufferOffset: Int, inputAggBufferOffset: Int): ApproximatePercentile = new ApproximatePercentile("percentile_approx", child, percentageExpression, accuracyExpression, mutableAggBufferOffset, inputAggBufferOffset) def apply( - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression): ApproximatePercentile = { + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression): ApproximatePercentile = { new ApproximatePercentile("percentile_approx", child, percentageExpression, accuracyExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94e99b46d044..d0e4071ce422 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ """, since = "1.0.0") case class Average( - funcName: String, child: Expression) + funcName: String, child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def nodeName: String = funcName From 125cfacb9c7f029793a7e462f43bcb6fe4a925f5 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 18 Dec 2019 23:11:00 +0530 Subject: [PATCH 22/32] expressionWithTreeNodeTag for ApproximatePercentile --- .../catalyst/analysis/FunctionRegistry.scala | 64 
++++++++++++++++++- .../sql/catalyst/expressions/Expression.scala | 4 +- .../aggregate/ApproximatePercentile.scala | 37 ++--------- .../ApproximatePercentileSuite.scala | 32 +++++----- .../sql/execution/command/CommandUtils.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- 6 files changed, 88 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d1a0717de3d8..19067f3a8fec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,8 +300,8 @@ object FunctionRegistry { expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), - expressionWithAlias[ApproximatePercentile]("percentile_approx"), - expressionWithAlias[ApproximatePercentile]("approx_percentile"), + expressionWithTNT[ApproximatePercentile]("percentile_approx"), + expressionWithTNT[ApproximatePercentile]("approx_percentile"), expressionWithAlias[StddevSamp]("std"), expressionWithAlias[StddevSamp]("stddev"), expressionWithAlias[StddevSamp]("stddev_samp"), @@ -669,6 +669,66 @@ object FunctionRegistry { (name, (expressionInfo[T](name), builder)) } + private def expressionWithTNT[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { + + // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main + // constructor and contains non-parameter `child` and should not be used as function builder. + val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) { + val all = tag.runtimeClass.getConstructors + val maxNumArgs = all.map(_.getParameterCount).max + all.filterNot(_.getParameterCount == maxNumArgs) + } else { + tag.runtimeClass.getConstructors + } + // See if we can find a constructor that accepts Seq[Expression] + val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]])) + val builder = (expressions: Seq[Expression]) => { + if (varargCtor.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + try { + varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + } catch { + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + case e: Exception => throw new AnalysisException(e.getCause.getMessage) + } + } else { + // Otherwise, find a constructor method that matches the number of arguments, and use that. + val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted + val invalidArgumentsMsg = if (validParametersCount.length == 0) { + s"Invalid arguments for function $name" + } else { + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + s"Invalid number of arguments for function $name. 
" + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}" + } + throw new AnalysisException(invalidArgumentsMsg) + } + try { + val exp = f.newInstance(expressions : _*).asInstanceOf[Expression] + exp.setTagValue(exp.FUNC_ALIAS, name) + exp + } catch { + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + case e: Exception => throw new AnalysisException(e.getCause.getMessage) + } + } + } + + (name, (expressionInfo[T](name), builder)) + } + /** * Creates a function registry lookup entry for cast aliases (SPARK-16730). * For example, if name is "int", and dataType is IntegerType, this means int(x) would become diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4632957e7afd..c47893455636 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -286,6 +286,8 @@ abstract class Expression extends TreeNode[Expression] { override def simpleStringWithNodeId(): String = { throw new UnsupportedOperationException(s"$nodeName does not implement simpleStringWithNodeId") } + + val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 1d3c99913349..fb60fb90ad00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -67,7 +67,6 @@ import org.apache.spark.sql.types._ """, since = "2.1.0") case class ApproximatePercentile( - funcName: String, child: Expression, percentageExpression: Expression, accuracyExpression: Expression, @@ -75,17 +74,12 @@ case class ApproximatePercentile( override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { - def this( - funcName: String, - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression) = { - this(funcName, child, percentageExpression, accuracyExpression, 0, 0) + def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { + this(child, percentageExpression, accuracyExpression, 0, 0) } - def this(funcName: String, child: Expression, percentageExpression: Expression) = { - this(funcName, child, percentageExpression, - Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) } // Mark as lazy 
so that accuracyExpression is not evaluated during tree transformation. @@ -191,7 +185,7 @@ case class ApproximatePercentile( if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType } - override def nodeName: String = funcName + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("percentile_approx") override def serialize(obj: PercentileDigest): Array[Byte] = { ApproximatePercentile.serializer.serialize(obj) @@ -200,10 +194,6 @@ case class ApproximatePercentile( override def deserialize(bytes: Array[Byte]): PercentileDigest = { ApproximatePercentile.serializer.deserialize(bytes) } - - override def flatArguments: Iterator[Any] = - Iterator(child, percentageExpression, accuracyExpression, - mutableAggBufferOffset, inputAggBufferOffset) } object ApproximatePercentile { @@ -332,21 +322,4 @@ object ApproximatePercentile { val serializer: PercentileDigestSerializer = new PercentileDigestSerializer - def apply( - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression, - mutableAggBufferOffset: Int, inputAggBufferOffset: Int): ApproximatePercentile = - new ApproximatePercentile("percentile_approx", child, percentageExpression, - accuracyExpression, mutableAggBufferOffset, inputAggBufferOffset) - - def apply( - child: Expression, - percentageExpression: Expression, - accuracyExpression: Expression): ApproximatePercentile = { - new ApproximatePercentile("percentile_approx", child, percentageExpression, accuracyExpression) - } - - def apply(child: Expression, percentageExpression: Expression): ApproximatePercentile = - new ApproximatePercentile("percentile_approx", child, percentageExpression) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index bad3f24178af..84b3cc79cef5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -53,7 +53,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) - val agg = ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5)) + val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5)) assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } @@ -103,7 +103,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) val accuracyExpression = Literal(10000) - val agg = ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) + val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) assert(agg.nullable) val group1 = (0 until data.length / 2) @@ -140,7 +140,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { val percentage = 0.5D // Phase one, partial mode aggregation - val agg = ApproximatePercentile(childExpression, Literal(percentage)) + val agg = new ApproximatePercentile(childExpression, Literal(percentage)) .withNewInputAggBufferOffset(inputAggregationBufferOffset) .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) @@ 
-170,12 +170,12 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql, single percentile assertEqual( s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", - ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) // sql, array of percentile assertEqual( s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - ApproximatePercentile( + new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql: String) @@ -183,13 +183,13 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = false), single percentile assertEqual( s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", - ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = false)) // sql(isDistinct = false), array of percentile assertEqual( s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - ApproximatePercentile( + new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql(isDistinct = false)) @@ -197,13 +197,13 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = true), single percentile assertEqual( s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)", - ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = true)) // sql(isDistinct = true), array of percentile assertEqual( s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", - ApproximatePercentile( + new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) ).sql(isDistinct = true)) @@ -211,7 +211,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") { val attribute = AttributeReference("a", DoubleType)() - val wrongAccuracy = ApproximatePercentile( + val wrongAccuracy = new ApproximatePercentile( attribute, percentageExpression = Literal(0.5D), accuracyExpression = AttributeReference("b", IntegerType)()) @@ -221,7 +221,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { TypeCheckFailure("The accuracy or percentage provided must be a constant literal") ) - val wrongPercentage = ApproximatePercentile( + val wrongPercentage = new ApproximatePercentile( attribute, percentageExpression = attribute, accuracyExpression = Literal(10000)) @@ -233,7 +233,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } test("class ApproximatePercentile, fails analysis if parameters are invalid") { - val wrongAccuracy = ApproximatePercentile( + val wrongAccuracy = new ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = Literal(0.5D), accuracyExpression = Literal(-1)) @@ -249,7 +249,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_))) ) correctPercentageExpresions.foreach { percentageExpression => - val correctPercentage = ApproximatePercentile( + val correctPercentage = new ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = percentageExpression, accuracyExpression = Literal(100)) @@ -265,7 +265,7 @@ class ApproximatePercentileSuite 
extends SparkFunSuite { ) wrongPercentageExpressions.foreach { percentageExpression => - val wrongPercentage = ApproximatePercentile( + val wrongPercentage = new ApproximatePercentile( AttributeReference("a", DoubleType)(), percentageExpression = percentageExpression, accuracyExpression = Literal(100)) @@ -289,7 +289,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { accuracyExpressions.foreach { accuracyExpression => percentageExpressions.foreach { percentageExpression => - val agg = ApproximatePercentile( + val agg = new ApproximatePercentile( UnresolvedAttribute("a"), percentageExpression, accuracyExpression) @@ -309,7 +309,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, null handling") { val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) - val agg = ApproximatePercentile(childExpression, Literal(0.5D)) + val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 422dde5d0fc1..b644e6dc471d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -214,7 +214,7 @@ object CommandUtils extends Logging { val namedExprs = attrsToGenHistogram.map { attr => val aggFunc = - ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) + new ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) val expr = aggFunc.toAggregateExpression() Alias(expr, expr.toString)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index e62e64fbe858..4ddf35cbb3ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -248,7 +248,7 @@ object StatFunctions extends Logging { percentileIndex += 1 (child: Expression) => GetArrayItem( - ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), + new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), Literal(index)) } else { stats.toLowerCase(Locale.ROOT) match { From 4ca20f488711fab43401c3a8a10516a9aa87b7a2 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 18 Dec 2019 23:35:23 +0530 Subject: [PATCH 23/32] expressionWithTreeNodeTag for BoolAnd, BoolOr, StddevSamp and VarianceSamp --- .../catalyst/analysis/FunctionRegistry.scala | 20 +++++++++---------- .../sql/catalyst/analysis/TypeCoercion.scala | 4 ++-- .../aggregate/CentralMomentAgg.scala | 12 ++++------- .../aggregate/UnevaluableAggs.scala | 10 ++++------ .../ExpressionTypeCheckingSuite.scala | 4 ++-- .../sql/execution/stat/StatFunctions.scala | 3 +-- .../org/apache/spark/sql/functions.scala | 8 ++++---- 7 files changed, 27 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 19067f3a8fec..eb8abb33168e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -302,22 +302,22 @@ object FunctionRegistry { expression[Skewness]("skewness"), expressionWithTNT[ApproximatePercentile]("percentile_approx"), expressionWithTNT[ApproximatePercentile]("approx_percentile"), - expressionWithAlias[StddevSamp]("std"), - expressionWithAlias[StddevSamp]("stddev"), - expressionWithAlias[StddevSamp]("stddev_samp"), + expressionWithTNT[StddevSamp]("std"), + expressionWithTNT[StddevSamp]("stddev"), + expressionWithTNT[StddevSamp]("stddev_samp"), expression[StddevPop]("stddev_pop"), expression[Sum]("sum"), - expressionWithAlias[VarianceSamp]("variance"), - expressionWithAlias[VarianceSamp]("var_samp"), + expressionWithTNT[VarianceSamp]("variance"), + expressionWithTNT[VarianceSamp]("var_samp"), expression[VariancePop]("var_pop"), expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expressionWithAlias[BoolAnd]("every"), - expressionWithAlias[BoolAnd]("bool_and"), - expressionWithAlias[BoolOr]("any"), - expressionWithAlias[BoolOr]("some"), - expressionWithAlias[BoolOr]("bool_or"), + expressionWithTNT[BoolAnd]("every"), + expressionWithTNT[BoolAnd]("bool_and"), + expressionWithTNT[BoolOr]("any"), + expressionWithTNT[BoolOr]("some"), + expressionWithTNT[BoolOr]("bool_or"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 763365f04965..eab00b03c729 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -431,11 +431,11 @@ object TypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(funcName, e @ StringType()) => Average(funcName, Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) - case StddevSamp(funcName, e @ StringType()) => StddevSamp(funcName, Cast(e, DoubleType)) + case StddevSamp( e @ StringType()) => StddevSamp(Cast(e, DoubleType)) case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(funcName, e @ StringType()) => VarianceSamp(funcName, Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index b7686f85a016..6c9ec6c19a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -165,7 +165,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { """, since = "1.6.0") // scalastyle:on line.size.limit -case class StddevSamp(funcName: String, child: 
Expression) extends CentralMomentAgg(child) { +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -174,9 +174,7 @@ case class StddevSamp(funcName: String, child: Expression) extends CentralMoment If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } - override def nodeName: String = funcName - - override def flatArguments: Iterator[Any] = Iterator(child) + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("stddev_samp") } // Compute the population variance of a column @@ -208,7 +206,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { 1.0 """, since = "1.6.0") -case class VarianceSamp(funcName: String, child: Expression) extends CentralMomentAgg(child) { +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -217,9 +215,7 @@ case class VarianceSamp(funcName: String, child: Expression) extends CentralMome If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } - override def nodeName: String = funcName - - override def flatArguments: Iterator[Any] = Iterator(child) + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("var_samp") } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index 6a94933f3b02..11a2d7ca016c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -38,8 +38,6 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) case _ => TypeCheckResult.TypeCheckSuccess } } - - override def flatArguments: Iterator[Any] = Iterator(arg) } @ExpressionDescription( @@ -54,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) false """, since = "3.0.0") -case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = funcName +case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("bool_and") } @ExpressionDescription( @@ -70,6 +68,6 @@ case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBoolean false """, since = "3.0.0") -case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = funcName +case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("bool_or") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c83759e8f4c1..feb927264ba6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,8 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) - assertSuccess(new BoolAnd("bool_and", 'booleanField)) - assertSuccess(new BoolOr("bool_or", 'booleanField)) + 
assertSuccess(new BoolAnd('booleanField)) + assertSuccess(new BoolOr('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 4ddf35cbb3ed..a6c9c2972df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -254,8 +254,7 @@ object StatFunctions extends Logging { stats.toLowerCase(Locale.ROOT) match { case "count" => (child: Expression) => Count(child).toAggregateExpression() case "mean" => (child: Expression) => Average(child).toAggregateExpression() - case "stddev" => (child: Expression) => - StddevSamp("stddev", child).toAggregateExpression() + case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() case "min" => (child: Expression) => Min(child).toAggregateExpression() case "max" => (child: Expression) => Max(child).toAggregateExpression() case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b52ba2474a47..8b4bee4bb3b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -675,7 +675,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withAggregateFunction { StddevSamp("stddev", e.expr) } + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: alias for `stddev_samp`. @@ -692,7 +692,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp("stddev_samp", e.expr) } + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the sample standard deviation of @@ -759,7 +759,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = withAggregateFunction { VarianceSamp("variance", e.expr) } + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: alias for `var_samp`. @@ -775,7 +775,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp("var_samp", e.expr) } + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. 
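The pattern this patch establishes is the one the rest of the series converges on: rather than threading a funcName: String through every aggregate's constructor, the registry attaches the SQL-facing alias to the freshly built expression as a TreeNodeTag, and nodeName falls back to the canonical name whenever no tag is present. The following minimal sketch illustrates the mechanism; TreeNodeTag, FuncAlias, FuncNode and FuncAliasDemo are simplified stand-ins written for this illustration, not catalyst's real classes.

import scala.collection.mutable

// Stand-in for catalyst's TreeNodeTag: a typed key for per-node metadata.
final case class TreeNodeTag[T](name: String)

object FuncAlias {
  // Mirrors the FUNC_ALIAS tag used in the patch above.
  val FUNC_ALIAS: TreeNodeTag[String] = TreeNodeTag[String]("functionAliasName")
}

// A toy expression node carrying a canonical name plus optional tags.
class FuncNode(defaultName: String) {
  private val tags = mutable.Map.empty[TreeNodeTag[_], Any]

  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = tags(tag) = value

  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] =
    tags.get(tag).map(_.asInstanceOf[T])

  // Same shape as StddevSamp.nodeName after this patch: prefer the
  // registered alias, otherwise report the canonical function name.
  def nodeName: String = getTagValue(FuncAlias.FUNC_ALIAS).getOrElse(defaultName)
}

object FuncAliasDemo extends App {
  val viaStd = new FuncNode("stddev_samp")
  // Roughly what registering the function under the name "std" arranges
  // via setTagValue in the registry's builder.
  viaStd.setTagValue(FuncAlias.FUNC_ALIAS, "std")
  println(viaStd.nodeName)                      // prints "std"
  println(new FuncNode("stddev_samp").nodeName) // prints "stddev_samp"
}

Because the alias no longer sits in the constructor, the case classes keep their generated equals, copy and pattern-matching behaviour, which is also what lets the flatArguments overrides above be deleted.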
From aecdd8a1e20aee18b0f5a371512b46027393a88b Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Wed, 18 Dec 2019 23:57:55 +0530 Subject: [PATCH 24/32] expressionWithTreeNodeTag for First, Last and Average --- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++---- .../catalyst/analysis/FunctionRegistry.scala | 12 ++++++------ .../sql/catalyst/analysis/TypeCoercion.scala | 14 +++++++------- .../spark/sql/catalyst/dsl/package.scala | 4 ++-- .../expressions/aggregate/Average.scala | 11 ++--------- .../catalyst/expressions/aggregate/First.scala | 18 +++--------------- .../catalyst/expressions/aggregate/Last.scala | 18 +++--------------- .../sql/catalyst/optimizer/Optimizer.scala | 10 +++++----- .../optimizer/ReplaceOperatorSuite.scala | 2 +- 9 files changed, 31 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4a90e25d57d0..c878bac25942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -692,10 +692,8 @@ class Analyzer( // Assumption is the aggregate function ignores nulls. This is true for all current // AggregateFunction's with the exception of First and Last in their default mode // (which we handle) and possibly some Hive UDAF's. - case First(funcName, expr, _) => - First(funcName, ifExpr(expr), Literal(true)) - case Last(funcName, expr, _) => - Last(funcName, ifExpr(expr), Literal(true)) + case First(expr, _) => First(ifExpr(expr), Literal(true)) + case Last(expr, _) => Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) }.transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index eb8abb33168e..79ab81c5c7de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -282,18 +282,18 @@ object FunctionRegistry { // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), - expressionWithAlias[Average]("avg"), - expressionWithAlias[Average]("mean"), + expressionWithTNT[Average]("avg"), + expressionWithTNT[Average]("mean"), expression[Corr]("corr"), expression[Count]("count"), expression[CountIf]("count_if"), expression[CovPopulation]("covar_pop"), expression[CovSample]("covar_samp"), - expressionWithAlias[First]("first"), - expressionWithAlias[First]("first_value"), + expressionWithTNT[First]("first"), + expressionWithTNT[First]("first_value"), expression[Kurtosis]("kurtosis"), - expressionWithAlias[Last]("last"), - expressionWithAlias[Last]("last_value"), + expressionWithTNT[Last]("last"), + expressionWithTNT[Last]("last_value"), expression[Max]("max"), expression[MaxBy]("max_by"), expression[Min]("min"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index eab00b03c729..1308d96fbb35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -429,7 +429,7 @@ 
object TypeCoercion { case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case Average(funcName, e @ StringType()) => Average(funcName, Cast(e, DoubleType)) + case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp( e @ StringType()) => StddevSamp(Cast(e, DoubleType)) case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) @@ -613,15 +613,15 @@ object TypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ Average(_, DecimalType()) => s // Decimal is already the biggest. - case Average(funcName, e @ IntegralType()) if e.dataType != LongType => - Average(funcName, Cast(e, LongType)) - case Average(funcName, e @ FractionalType()) if e.dataType != DoubleType => - Average(funcName, Cast(e, DoubleType)) + case s @ Average(DecimalType()) => s // Decimal is already the biggest. + case Average(e @ IntegralType()) if e.dataType != LongType => + Average(Cast(e, LongType)) + case Average(e @ FractionalType()) if e.dataType != DoubleType => + Average(Cast(e, DoubleType)) // Hive lets you do aggregation of timestamps... for some reason case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) - case Average(funcName, e @ TimestampType()) => Average(funcName, Cast(e, DoubleType)) + case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 34822f617bac..b4a8bafe22df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -173,8 +173,8 @@ package object dsl { def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = HyperLogLogPlusPlus(e, rsd).toAggregateExpression() def avg(e: Expression): Expression = Average(e).toAggregateExpression() - def first(e: Expression): Expression = First(e).toAggregateExpression() - def last(e: Expression): Expression = Last(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() def min(e: Expression): Expression = Min(e).toAggregateExpression() def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true) def max(e: Expression): Expression = Max(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index d0e4071ce422..0457ddec662f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -35,11 +35,10 @@ import org.apache.spark.sql.types._ -3 days -11 hours -59 minutes -59 seconds """, since = "1.0.0") -case class Average( - funcName: String, child: Expression) +case class Average(child: Expression) extends DeclarativeAggregate with 
ImplicitCastInputTypes { - override def nodeName: String = funcName + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("avg") override def children: Seq[Expression] = child :: Nil @@ -95,10 +94,4 @@ case class Average( coalesce(child.cast(sumDataType), Literal.default(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) - - override def flatArguments: Iterator[Any] = Iterator(child) -} - -object Average{ - def apply(child: Expression): Average = Average("avg", child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index af8b2c0b1fda..9c3b4e43f8ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -44,11 +44,10 @@ import org.apache.spark.sql.types._ 5 """, since = "2.0.0") -case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression) +case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate with ExpectsInputTypes { - def this(funcName: String, child: Expression) = - this(funcName, child, Literal.create(false, BooleanType)) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil @@ -114,18 +113,7 @@ case class First(funcName: String, child: Expression, ignoreNullsExpr: Expressio override lazy val evaluateExpression: AttributeReference = first - override def nodeName: String = funcName + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("first") override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" - - override def flatArguments: Iterator[Any] = Iterator(child, ignoreNullsExpr) -} - -object First { - - def apply(child: Expression, ignoreNullsExpr: Expression): First = - new First("first", child, ignoreNullsExpr) - - def apply(child: Expression): First = - new First("first", child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 535a5f72f17f..82f7fb91d39c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -44,11 +44,10 @@ import org.apache.spark.sql.types._ 5 """, since = "2.0.0") -case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression) +case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate with ExpectsInputTypes { - def this(funcName: String, child: Expression) = - this(funcName, child, Literal.create(false, BooleanType)) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil @@ -112,18 +111,7 @@ case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression override lazy val evaluateExpression: AttributeReference = last - override def nodeName: String = funcName + override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("last") override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" - - override def flatArguments: Iterator[Any] = 
Iterator(child, ignoreNullsExpr) -} - -object Last { - - def apply(child: Expression, ignoreNullsExpr: Expression): Last = - new Last("last", child, ignoreNullsExpr) - - def apply(child: Expression): Last = - new Last("last", child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4f7201f99a4..05fd5e35e22a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1464,9 +1464,9 @@ object DecimalAggregates extends Rule[LogicalPlan] { MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) - case Average(f, e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => val newAggExpr = - we.copy(windowFunction = ae.copy(aggregateFunction = Average(f, UnscaledValue(e)))) + we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) @@ -1477,8 +1477,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case Average(f, e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = ae.copy(aggregateFunction = Average(f, UnscaledValue(e))) + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) @@ -1539,7 +1539,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { if (keyExprIds.contains(attr.exprId)) { attr } else { - Alias(First(attr).toAggregateExpression(), attr.name)(attr.exprId) + Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 6bf3e3b94ecc..9bf864f5201f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -203,7 +203,7 @@ class ReplaceOperatorSuite extends PlanTest { Seq(attrA), Seq( attrA, - Alias(First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId) + Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId) ), input) From bbd43973f4fbb1cb8875cfb38fcf67820b46fc53 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 00:00:29 +0530 Subject: [PATCH 25/32] Renaming expressionWithTNT to expressionWithAlias --- .../catalyst/analysis/FunctionRegistry.scala | 74 +++++-------------- 1 file changed, 18 insertions(+), 56 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 79ab81c5c7de..2e73f27a2bf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -282,42 +282,42 @@ object FunctionRegistry { // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), - expressionWithTNT[Average]("avg"), - expressionWithTNT[Average]("mean"), + expressionWithAlias[Average]("avg"), + expressionWithAlias[Average]("mean"), expression[Corr]("corr"), expression[Count]("count"), expression[CountIf]("count_if"), expression[CovPopulation]("covar_pop"), expression[CovSample]("covar_samp"), - expressionWithTNT[First]("first"), - expressionWithTNT[First]("first_value"), + expressionWithAlias[First]("first"), + expressionWithAlias[First]("first_value"), expression[Kurtosis]("kurtosis"), - expressionWithTNT[Last]("last"), - expressionWithTNT[Last]("last_value"), + expressionWithAlias[Last]("last"), + expressionWithAlias[Last]("last_value"), expression[Max]("max"), expression[MaxBy]("max_by"), expression[Min]("min"), expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), - expressionWithTNT[ApproximatePercentile]("percentile_approx"), - expressionWithTNT[ApproximatePercentile]("approx_percentile"), - expressionWithTNT[StddevSamp]("std"), - expressionWithTNT[StddevSamp]("stddev"), - expressionWithTNT[StddevSamp]("stddev_samp"), + expressionWithAlias[ApproximatePercentile]("percentile_approx"), + expressionWithAlias[ApproximatePercentile]("approx_percentile"), + expressionWithAlias[StddevSamp]("std"), + expressionWithAlias[StddevSamp]("stddev"), + expressionWithAlias[StddevSamp]("stddev_samp"), expression[StddevPop]("stddev_pop"), expression[Sum]("sum"), - expressionWithTNT[VarianceSamp]("variance"), - expressionWithTNT[VarianceSamp]("var_samp"), + expressionWithAlias[VarianceSamp]("variance"), + expressionWithAlias[VarianceSamp]("var_samp"), expression[VariancePop]("var_pop"), expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expressionWithTNT[BoolAnd]("every"), - expressionWithTNT[BoolAnd]("bool_and"), - expressionWithTNT[BoolOr]("any"), - expressionWithTNT[BoolOr]("some"), - expressionWithTNT[BoolOr]("bool_or"), + expressionWithAlias[BoolAnd]("every"), + expressionWithAlias[BoolAnd]("bool_and"), + expressionWithAlias[BoolOr]("any"), + expressionWithAlias[BoolOr]("some"), + expressionWithAlias[BoolOr]("bool_or"), // string functions expression[Ascii]("ascii"), @@ -633,44 +633,6 @@ object FunctionRegistry { private def expressionWithAlias[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { - val constructors = tag.runtimeClass.getConstructors - .filter(_.getParameterTypes.head == classOf[String]) - assert(constructors.length >= 1, - s"there is no constructor for ${tag.runtimeClass} " + - "which takes String as first argument") - val builder = (expressions: Seq[Expression]) => { - val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression]) - val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors - .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression])) - 
.map(_.getParameterCount - 1).distinct.sorted - val invalidArgumentsMsg = if (validParametersCount.length == 0) { - s"Invalid arguments for function $name" - } else { - val expectedNumberOfParameters = if (validParametersCount.length == 1) { - validParametersCount.head.toString - } else { - validParametersCount.init.mkString("one of ", ", ", " and ") + - validParametersCount.last - } - s"Invalid number of arguments for function $name. " + - s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}" - } - throw new AnalysisException(invalidArgumentsMsg) - } - try { - f.newInstance(name.toString +: expressions: _*).asInstanceOf[Expression] - } catch { - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - case e: Exception => throw new AnalysisException(e.getCause.getMessage) - } - } - (name, (expressionInfo[T](name), builder)) - } - - private def expressionWithTNT[T <: Expression](name: String) - (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main // constructor and contains non-parameter `child` and should not be used as function builder. From 914691308b559500cf83d0868ff968849123d7f6 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 00:14:33 +0530 Subject: [PATCH 26/32] nit --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++-- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 4 ++-- .../spark/sql/catalyst/expressions/aggregate/Average.scala | 3 +-- .../org/apache/spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- .../benchmark/ObjectHashAggregateExecBenchmark.scala | 2 +- .../spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 2 +- 7 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c878bac25942..659e4a5c86ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -692,8 +692,10 @@ class Analyzer( // Assumption is the aggregate function ignores nulls. This is true for all current // AggregateFunction's with the exception of First and Last in their default mode // (which we handle) and possibly some Hive UDAF's. 
- case First(expr, _) => First(ifExpr(expr), Literal(true)) - case Last(expr, _) => Last(ifExpr(expr), Literal(true)) + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) }.transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1308d96fbb35..e76193fd9422 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -431,7 +431,7 @@ object TypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) - case StddevSamp( e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) @@ -613,7 +613,7 @@ object TypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ Average(DecimalType()) => s // Decimal is already the biggest. + case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) case Average(e @ FractionalType()) if e.dataType != DoubleType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 0457ddec662f..df76507390c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -35,8 +35,7 @@ import org.apache.spark.sql.types._ -3 days -11 hours -59 minutes -59 seconds """, since = "1.0.0") -case class Average(child: Expression) - extends DeclarativeAggregate with ImplicitCastInputTypes { +case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("avg") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 54655b743c7d..52bd0ecb1fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -278,7 +278,7 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericOrIntervalColumns(colNames : _*)(Average.apply) + aggregateNumericOrIntervalColumns(colNames : _*)(Average) } /** @@ -302,7 +302,7 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericOrIntervalColumns(colNames : _*)(Average.apply) + 
aggregateNumericOrIntervalColumns(colNames : _*)(Average) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8b4bee4bb3b1..72e9e337c425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -431,7 +431,7 @@ object functions { * @since 2.0.0 */ def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - First(e.expr, Literal(ignoreNulls)) + new First(e.expr, Literal(ignoreNulls)) } /** @@ -556,7 +556,7 @@ object functions { * @since 2.0.0 */ def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - Last(e.expr, Literal(ignoreNulls)) + new Last(e.expr, Literal(ignoreNulls)) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 5787af91199a..c475c7b21ab9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -208,7 +208,7 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark { private def percentile_approx( column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = ApproximatePercentile(column.expr, Literal(percentage)) + val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) Column(approxPercentile.toAggregateExpression(isDistinct)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index b6ef34dcfd43..930f80146749 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -175,7 +175,7 @@ class ObjectHashAggregateSuite private def percentile_approx( column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = ApproximatePercentile(column.expr, Literal(percentage)) + val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) Column(approxPercentile.toAggregateExpression(isDistinct)) } From 8e9e42bc99c2ed5a0cae299a009e4a409e01ce03 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 00:30:31 +0530 Subject: [PATCH 27/32] Avoid duplicate code --- .../catalyst/analysis/FunctionRegistry.scala | 63 ++----------------- 1 file changed, 5 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 2e73f27a2bf8..124fd0a6f23d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -573,7 +573,7 @@ object FunctionRegistry { val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet /** See usage above. 
*/ - private def expression[T <: Expression](name: String) + private def expression[T <: Expression](name: String, isAliasName: Boolean = false) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main @@ -619,7 +619,9 @@ object FunctionRegistry { throw new AnalysisException(invalidArgumentsMsg) } try { - f.newInstance(expressions : _*).asInstanceOf[Expression] + val exp = f.newInstance(expressions : _*).asInstanceOf[Expression] + if (isAliasName) exp.setTagValue(exp.FUNC_ALIAS, name) + exp } catch { // the exception is an invocation exception. To get a meaningful message, we need the // cause. @@ -633,62 +635,7 @@ object FunctionRegistry { private def expressionWithAlias[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { - - // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main - // constructor and contains non-parameter `child` and should not be used as function builder. - val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) { - val all = tag.runtimeClass.getConstructors - val maxNumArgs = all.map(_.getParameterCount).max - all.filterNot(_.getParameterCount == maxNumArgs) - } else { - tag.runtimeClass.getConstructors - } - // See if we can find a constructor that accepts Seq[Expression] - val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]])) - val builder = (expressions: Seq[Expression]) => { - if (varargCtor.isDefined) { - // If there is an apply method that accepts Seq[Expression], use that one. - try { - varargCtor.get.newInstance(expressions).asInstanceOf[Expression] - } catch { - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - case e: Exception => throw new AnalysisException(e.getCause.getMessage) - } - } else { - // Otherwise, find a constructor method that matches the number of arguments, and use that. - val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors - .filter(_.getParameterTypes.forall(_ == classOf[Expression])) - .map(_.getParameterCount).distinct.sorted - val invalidArgumentsMsg = if (validParametersCount.length == 0) { - s"Invalid arguments for function $name" - } else { - val expectedNumberOfParameters = if (validParametersCount.length == 1) { - validParametersCount.head.toString - } else { - validParametersCount.init.mkString("one of ", ", ", " and ") + - validParametersCount.last - } - s"Invalid number of arguments for function $name. " + - s"Expected: $expectedNumberOfParameters; Found: ${params.length}" - } - throw new AnalysisException(invalidArgumentsMsg) - } - try { - val exp = f.newInstance(expressions : _*).asInstanceOf[Expression] - exp.setTagValue(exp.FUNC_ALIAS, name) - exp - } catch { - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. 
- case e: Exception => throw new AnalysisException(e.getCause.getMessage) - } - } - } - - (name, (expressionInfo[T](name), builder)) + expression[T](name, true) } /** From 36418e2a54ba16e268cd54317f138433fe8fee3a Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 01:03:18 +0530 Subject: [PATCH 28/32] small fix --- .../apache/spark/sql/catalyst/optimizer/finishAnalysis.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index c33027434152..f64b6e00373f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral)) - case BoolOr(_, arg) => Max(arg) - case BoolAnd(_, arg) => Min(arg) + case BoolOr(arg) => Max(arg) + case BoolAnd(arg) => Min(arg) } } From ce8ea17c9a1bf4133e6b22fdc0768c35d2f9d924 Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 14:33:00 +0530 Subject: [PATCH 29/32] move FUNC_ALIAS to FunctionRegistry --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 8 +++++--- .../spark/sql/catalyst/expressions/Expression.scala | 4 +--- .../expressions/aggregate/ApproximatePercentile.scala | 5 +++-- .../sql/catalyst/expressions/aggregate/Average.scala | 4 ++-- .../catalyst/expressions/aggregate/CentralMomentAgg.scala | 5 +++-- .../spark/sql/catalyst/expressions/aggregate/First.scala | 4 ++-- .../spark/sql/catalyst/expressions/aggregate/Last.scala | 4 ++-- .../catalyst/expressions/aggregate/UnevaluableAggs.scala | 6 +++--- 8 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 124fd0a6f23d..9110395f830b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -22,7 +22,6 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.types._ @@ -193,6 +193,8 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression + val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName") + // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions @@ -573,7 +575,7 @@ object FunctionRegistry { val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet /** See usage 
above. */ - private def expression[T <: Expression](name: String, isAliasName: Boolean = false) + private def expression[T <: Expression](name: String, setAlias: Boolean = false) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main @@ -620,7 +622,7 @@ object FunctionRegistry { } try { val exp = f.newInstance(expressions : _*).asInstanceOf[Expression] - if (isAliasName) exp.setTagValue(exp.FUNC_ALIAS, name) + if (setAlias) exp.setTagValue(FUNC_ALIAS, name) exp } catch { // the exception is an invocation exception. To get a meaningful message, we need the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c47893455636..4632957e7afd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -286,8 +286,6 @@ abstract class Expression extends TreeNode[Expression] { override def simpleStringWithNodeId(): String = { throw new UnsupportedOperationException(s"$nodeName does not implement simpleStringWithNodeId") } - - val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index fb60fb90ad00..9bc423c615b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import com.google.common.primitives.{Doubles, Ints, Longs} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest @@ -185,7 +185,8 @@ case class ApproximatePercentile( if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType } - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("percentile_approx") + override def nodeName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("percentile_approx") override def serialize(obj: PercentileDigest): Array[Byte] = { ApproximatePercentile.serializer.serialize(obj) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index df76507390c9..d788aa27459d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ since = "1.0.0") case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("avg") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg") override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 6c9ec6c19a45..927fe91b7533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -174,7 +175,7 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("stddev_samp") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp") } // Compute the population variance of a column @@ -215,7 +216,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("var_samp") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp") } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9c3b4e43f8ca..d5d03357be89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -113,7 +113,7 @@ case class First(child: Expression, 
ignoreNullsExpr: Expression) override lazy val evaluateExpression: AttributeReference = first - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("first") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first") override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 82f7fb91d39c..dbafc52464fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -111,7 +111,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val evaluateExpression: AttributeReference = last - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("last") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last") override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index 11a2d7ca016c..a1cd4a77d044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -53,7 +53,7 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) """, since = "3.0.0") case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("bool_and") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and") } @ExpressionDescription( @@ -69,5 +69,5 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { """, since = "3.0.0") case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = getTagValue(FUNC_ALIAS).getOrElse("bool_or") + override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or") } From 4b536dd56464e205412ec02738ef4145517f7fbc Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Thu, 19 Dec 2019 14:38:14 +0530 Subject: [PATCH 30/32] Remove expressionWithAlias --- .../catalyst/analysis/FunctionRegistry.scala | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9110395f830b..45af0f34b525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -284,42 +284,42 @@ object FunctionRegistry { // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), - expressionWithAlias[Average]("avg"), - expressionWithAlias[Average]("mean"), + expression[Average]("avg"), + expression[Average]("mean", true), expression[Corr]("corr"), expression[Count]("count"), expression[CountIf]("count_if"), expression[CovPopulation]("covar_pop"), expression[CovSample]("covar_samp"), - expressionWithAlias[First]("first"), - expressionWithAlias[First]("first_value"), + expression[First]("first"), + expression[First]("first_value", true), expression[Kurtosis]("kurtosis"), - expressionWithAlias[Last]("last"), - expressionWithAlias[Last]("last_value"), + expression[Last]("last"), + expression[Last]("last_value", true), expression[Max]("max"), expression[MaxBy]("max_by"), expression[Min]("min"), expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), - expressionWithAlias[ApproximatePercentile]("percentile_approx"), - expressionWithAlias[ApproximatePercentile]("approx_percentile"), - expressionWithAlias[StddevSamp]("std"), - expressionWithAlias[StddevSamp]("stddev"), - expressionWithAlias[StddevSamp]("stddev_samp"), + expression[ApproximatePercentile]("percentile_approx"), + expression[ApproximatePercentile]("approx_percentile", true), + expression[StddevSamp]("std", true), + expression[StddevSamp]("stddev", true), + expression[StddevSamp]("stddev_samp"), expression[StddevPop]("stddev_pop"), expression[Sum]("sum"), - expressionWithAlias[VarianceSamp]("variance"), - expressionWithAlias[VarianceSamp]("var_samp"), + expression[VarianceSamp]("variance", true), + expression[VarianceSamp]("var_samp"), expression[VariancePop]("var_pop"), expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expressionWithAlias[BoolAnd]("every"), - expressionWithAlias[BoolAnd]("bool_and"), - expressionWithAlias[BoolOr]("any"), - expressionWithAlias[BoolOr]("some"), - expressionWithAlias[BoolOr]("bool_or"), + expression[BoolAnd]("every", true), + expression[BoolAnd]("bool_and"), + expression[BoolOr]("any", true), + expression[BoolOr]("some", true), + expression[BoolOr]("bool_or"), // string functions expression[Ascii]("ascii"), @@ -634,12 +634,7 @@ object FunctionRegistry { (name, (expressionInfo[T](name), builder)) } - - private def expressionWithAlias[T <: Expression](name: String) - (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { - expression[T](name, true) - } - + /** * Creates a function registry lookup entry for cast aliases (SPARK-16730). 
From 737f33a6e167955762305436bc78672f9b73c342 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Thu, 19 Dec 2019 14:40:27 +0530
Subject: [PATCH 31/32] revert reorder

---
 .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 45af0f34b525..7a8b88e5264a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -285,7 +285,6 @@ object FunctionRegistry {
     // aggregate functions
     expression[HyperLogLogPlusPlus]("approx_count_distinct"),
     expression[Average]("avg"),
-    expression[Average]("mean", true),
     expression[Corr]("corr"),
     expression[Count]("count"),
     expression[CountIf]("count_if"),
@@ -298,6 +297,7 @@ object FunctionRegistry {
     expression[Last]("last_value", true),
     expression[Max]("max"),
     expression[MaxBy]("max_by"),
+    expression[Average]("mean", true),
     expression[Min]("min"),
     expression[MinBy]("min_by"),
     expression[Percentile]("percentile"),
@@ -306,12 +306,12 @@ object FunctionRegistry {
     expression[ApproximatePercentile]("approx_percentile", true),
     expression[StddevSamp]("std", true),
     expression[StddevSamp]("stddev", true),
-    expression[StddevSamp]("stddev_samp"),
     expression[StddevPop]("stddev_pop"),
+    expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
     expression[VarianceSamp]("variance", true),
-    expression[VarianceSamp]("var_samp"),
     expression[VariancePop]("var_pop"),
+    expression[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
@@ -634,7 +634,7 @@ object FunctionRegistry {
     (name, (expressionInfo[T](name), builder))
   }
 
-
+
   /**
    * Creates a function registry lookup entry for cast aliases (SPARK-16730).
   * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
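The final patch below moves the overrides from nodeName to prettyName. nodeName identifies the tree node when plans are printed, while prettyName is the user-facing function name that, among other things, feeds the auto-generated output column name; by default Spark derives prettyName from nodeName (roughly its lower-cased form), so overriding only the user-facing layer is the narrower change. A toy rendering of that relationship, as assumed here; the real defaults live in TreeNode and Expression.

import java.util.Locale

// Toy model: nodeName identifies the node; prettyName is what users see.
trait ToyExpression {
  def nodeName: String = getClass.getSimpleName
  def prettyName: String = nodeName.toLowerCase(Locale.ROOT)
}

// Before the patch the alias override sat on nodeName; afterwards only the
// user-facing name consults the alias, as in the hunks below.
class VarianceSamp(alias: Option[String]) extends ToyExpression {
  override def prettyName: String = alias.getOrElse("var_samp")
}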
From 1920940363194225c6c1fd93b08bf8d6e565503a Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Thu, 19 Dec 2019 14:45:22 +0530
Subject: [PATCH 32/32] override prettyName instead of nodeName

---
 .../expressions/aggregate/ApproximatePercentile.scala       | 2 +-
 .../spark/sql/catalyst/expressions/aggregate/Average.scala  | 2 +-
 .../catalyst/expressions/aggregate/CentralMomentAgg.scala   | 5 +++--
 .../spark/sql/catalyst/expressions/aggregate/First.scala    | 2 +-
 .../spark/sql/catalyst/expressions/aggregate/Last.scala     | 2 +-
 5 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 9bc423c615b3..b143ddef6a6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -185,7 +185,7 @@ case class ApproximatePercentile(
     if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }
 
-  override def nodeName: String =
+  override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("percentile_approx")
 
   override def serialize(obj: PercentileDigest): Array[Byte] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index d788aa27459d..9bb048a9851e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
-  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
   override def children: Seq[Expression] = child :: Nil
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 927fe91b7533..bf402807d62d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -175,7 +175,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
   }
 
-  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")
 }
 
 // Compute the population variance of a column
@@ -216,7 +217,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
   }
 
-  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index d5d03357be89..8de866ed9fb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -113,7 +113,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = first
 
-  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first")
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first")
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index dbafc52464fa..f8af0cd1f303 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -111,7 +111,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = last
 
-  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last")
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last")
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"
 }
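Taken together, the series means an aggregate keeps the name it was called by in user-visible output instead of collapsing to the canonical function name. A usage sketch of the intended effect; it assumes a live SparkSession named spark, and the column names in comments are expected values under that assumption, not verified output.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(10).toDF("x")

// Both names resolve to StddevSamp, but each keeps the name it was invoked by:
df.selectExpr("stddev(x)").columns        // expected: Array("stddev(x)")
df.selectExpr("stddev_samp(x)").columns   // expected: Array("stddev_samp(x)")

// Same pattern for first/first_value, avg/mean, variance/var_samp, any/bool_or:
df.selectExpr("first_value(x)").columns   // expected: Array("first_value(x)")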