From 93d2bde8c01b626456094f049911f976d1c1660d Mon Sep 17 00:00:00 2001
From: maryannxue
Date: Wed, 11 Jul 2018 15:50:50 -0700
Subject: [PATCH 1/2] [SPARK-24790] Allow complex aggregate expressions in Pivot

---
 .../sql/catalyst/analysis/Analyzer.scala      | 23 ++++++-----
 .../sql/catalyst/analysis/CheckAnalysis.scala | 38 +++++++++++--------
 .../test/resources/sql-tests/inputs/pivot.sql | 18 +++++++++
 .../resources/sql-tests/results/pivot.sql.out | 34 ++++++++++++++++-
 4 files changed, 83 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e187133d03b17..3103eb65b7955 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -509,12 +509,7 @@ class Analyzer(
         || !p.pivotColumn.resolved => p
       case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
         // Check all aggregate expressions.
-        aggregates.foreach { e =>
-          if (!isAggregateExpression(e)) {
-            throw new AnalysisException(
-              s"Aggregate expression required for pivot, found '$e'")
-          }
-        }
+        aggregates.foreach(checkValidAggregateExpression)
         // Group-by expressions coming from SQL are implicit and need to be deduced.
         val groupByExprs = groupByExprsOpt.getOrElse(
           (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
@@ -586,12 +581,16 @@ class Analyzer(
       }
     }

-    private def isAggregateExpression(expr: Expression): Boolean = {
-      expr match {
-        case Alias(e, _) => isAggregateExpression(e)
-        case AggregateExpression(_, _, _, _) => true
-        case _ => false
-      }
+    // TODO: Support Pandas UDF.
+    private def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case expr: AggregateExpression =>
+        checkAggregateFunctionArguments(
+          expr.aggregateFunction, _.isInstanceOf[AggregateExpression])
+      case e: Attribute =>
+        failAnalysis(
+          s"Aggregate expression required for pivot, but '${e.sql}' " +
+          s"did not appear in any aggregate function.")
+      case e => e.children.foreach(checkValidAggregateExpression)
     }
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index af256b98b34f3..a716110450fa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -60,6 +60,27 @@ trait CheckAnalysis extends PredicateHelper {
     case _ => None
   }

+  protected def checkAggregateFunctionArguments(
+      aggFunction: Expression,
+      isAggregateExpression: Expression => Boolean): Unit = {
+    aggFunction.children.foreach { child =>
+      child.foreach {
+        case expr: Expression if isAggregateExpression(expr) =>
+          failAnalysis(
+            s"It is not allowed to use an aggregate function in the argument of " +
+              s"another aggregate function. Please use the inner aggregate function " +
+              s"in a sub-query.")
+        case _ => // OK
+      }
+
+      if (!child.deterministic) {
+        failAnalysis(
+          s"nondeterministic expression ${child.sql} should not " +
+            s"appear in the arguments of an aggregate function.")
+      }
+    }
+  }
+
   private def checkLimitClause(limitExpr: Expression): Unit = {
     limitExpr match {
       case e if !e.foldable => failAnalysis(
@@ -171,22 +192,7 @@ trait CheckAnalysis extends PredicateHelper {
               case agg: AggregateExpression => agg.aggregateFunction
               case udf: PythonUDF => udf
             }
-            aggFunction.children.foreach { child =>
-              child.foreach {
-                case expr: Expression if isAggregateExpression(expr) =>
-                  failAnalysis(
-                    s"It is not allowed to use an aggregate function in the argument of " +
-                      s"another aggregate function. Please use the inner aggregate function " +
-                      s"in a sub-query.")
-                case other => // OK
-              }
-
-              if (!child.deterministic) {
-                failAnalysis(
-                  s"nondeterministic expression ${expr.sql} should not " +
-                    s"appear in the arguments of an aggregate function.")
-              }
-            }
+            checkAggregateFunctionArguments(aggFunction, isAggregateExpression)
           case e: Attribute if groupingExprs.isEmpty =>
             // Collect all [[AggregateExpressions]]s.
             val aggExprs = aggregateExprs.filter(_.collect {
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
index 01dea6c81c11b..b3d53adfbebe7 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
@@ -111,3 +111,21 @@ PIVOT (
   sum(earnings)
   FOR year IN (2012, 2013)
 );
+
+-- pivot with complex aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with invalid arguments in aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+);
diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
index 85e3488990e20..922d8b9f9152c 100644
--- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 13
+-- Number of queries: 15


 -- !query 0
@@ -176,7 +176,7 @@ PIVOT (
 struct<>
 -- !query 11 output
 org.apache.spark.sql.AnalysisException
-Aggregate expression required for pivot, found 'abs(earnings#x)';
+Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.;


 -- !query 12
@@ -192,3 +192,33 @@ struct<>
 -- !query 12 output
 org.apache.spark.sql.AnalysisException
 cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0
+
+
+-- !query 13
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 13 schema
+struct
+-- !query 13 output
+2012 15000 7501.0 20000 20001.0
+2013 48000 48001.0 30000 30001.0
+
+
+-- !query 14
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;

From e5d4384cf123d5da1cd2084d1535f02a2046f31e Mon Sep 17 00:00:00 2001
From: maryannxue
Date: Thu, 12 Jul 2018 10:30:22 -0700
Subject: [PATCH 2/2] Defer aggregate expression check

---
 .../sql/catalyst/analysis/Analyzer.scala      |  7 ++--
 .../sql/catalyst/analysis/CheckAnalysis.scala | 38 ++++++++-----------
 2 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3103eb65b7955..c91f103f82da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -581,11 +581,12 @@ class Analyzer(
       }
     }

+    // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF.
     // TODO: Support Pandas UDF.
     private def checkValidAggregateExpression(expr: Expression): Unit = expr match {
-      case expr: AggregateExpression =>
-        checkAggregateFunctionArguments(
-          expr.aggregateFunction, _.isInstanceOf[AggregateExpression])
+      case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis.
+      case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) =>
+        failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.")
       case e: Attribute =>
         failAnalysis(
           s"Aggregate expression required for pivot, but '${e.sql}' " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a716110450fa5..af256b98b34f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -60,27 +60,6 @@ trait CheckAnalysis extends PredicateHelper {
     case _ => None
   }

-  protected def checkAggregateFunctionArguments(
-      aggFunction: Expression,
-      isAggregateExpression: Expression => Boolean): Unit = {
-    aggFunction.children.foreach { child =>
-      child.foreach {
-        case expr: Expression if isAggregateExpression(expr) =>
-          failAnalysis(
-            s"It is not allowed to use an aggregate function in the argument of " +
-              s"another aggregate function. Please use the inner aggregate function " +
-              s"in a sub-query.")
-        case _ => // OK
-      }
-
-      if (!child.deterministic) {
-        failAnalysis(
-          s"nondeterministic expression ${child.sql} should not " +
-            s"appear in the arguments of an aggregate function.")
-      }
-    }
-  }
-
   private def checkLimitClause(limitExpr: Expression): Unit = {
     limitExpr match {
       case e if !e.foldable => failAnalysis(
@@ -192,7 +171,22 @@ trait CheckAnalysis extends PredicateHelper {
               case agg: AggregateExpression => agg.aggregateFunction
               case udf: PythonUDF => udf
             }
-            checkAggregateFunctionArguments(aggFunction, isAggregateExpression)
+            aggFunction.children.foreach { child =>
+              child.foreach {
+                case expr: Expression if isAggregateExpression(expr) =>
+                  failAnalysis(
+                    s"It is not allowed to use an aggregate function in the argument of " +
+                      s"another aggregate function. Please use the inner aggregate function " +
+                      s"in a sub-query.")
+                case other => // OK
+              }
+
+              if (!child.deterministic) {
+                failAnalysis(
+                  s"nondeterministic expression ${expr.sql} should not " +
+                    s"appear in the arguments of an aggregate function.")
+              }
+            }
           case e: Attribute if groupingExprs.isEmpty =>
             // Collect all [[AggregateExpressions]]s.
             val aggExprs = aggregateExprs.filter(_.collect {