From 2f7df5b40e121aa58bbafd16d223df08f3b4f411 Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Tue, 12 Jan 2021 13:22:46 +0800 Subject: [PATCH] [SPARK-34082][SQL] Window expressions with alias inside WHERE and HAVING clauses fail with explicit exceptions --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 9 ++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bf5dbb8200e8..eb4c40695c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2932,6 +2932,16 @@ class Analyzer(override val catalogManager: CatalogManager) Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) } // end of addWindow + private def windowFunctionAliasInCondition( + condition: Expression, + windowExpressions: Seq[NamedExpression]): Boolean = { + val referenceNames = condition.references.map(_.name).toSet + windowExpressions.exists { + case Alias(_, name) if referenceNames.exists(r => resolver(r, name)) => true + case _ => false + } + } + // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { @@ -2939,6 +2949,14 @@ class Analyzer(override val catalogManager: CatalogManager) case Filter(condition, _) if hasWindowFunction(condition) => throw QueryCompilationErrors.windowFunctionNotAllowedError("WHERE") + case p @ Project(projectList, Filter(condition, _)) if hasWindowFunction(projectList) => + val (windowExpressions, _) = extract(projectList) + if (windowFunctionAliasInCondition(condition, windowExpressions)) { + throw QueryCompilationErrors.windowFunctionNotAllowedError("WHERE") + } else { + p + } + case UnresolvedHaving(condition, _) if hasWindowFunction(condition) => throw QueryCompilationErrors.windowFunctionNotAllowedError("HAVING") @@ -2949,6 +2967,9 @@ class Analyzer(override val catalogManager: CatalogManager) hasWindowFunction(aggregateExprs) && a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) + if (windowFunctionAliasInCondition(condition, windowExpressions)) { + throw QueryCompilationErrors.windowFunctionNotAllowedError("HAVING") + } // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) // Add a Filter operator for conditions in the Having clause. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3568ad3a7343..426cabf9d5fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -914,6 +914,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE") checkAnalysisError( sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE") + checkAnalysisError( + sql("SELECT a, RANK() OVER(ORDER BY b) AS s FROM testData2 WHERE b = 2 AND s = 1"), "WHERE") checkAnalysisError( sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"), "HAVING") @@ -927,6 +929,13 @@ class DataFrameWindowFunctionsSuite extends QueryTest |GROUP BY a |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin), "HAVING") + checkAnalysisError( + sql( + s"""SELECT a, MAX(b), RANK() OVER(ORDER BY a) AS s + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND s = 1""".stripMargin), + "HAVING") } test("window functions in multiple selects") {