diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 18c40b370cb5..427f196344df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -178,7 +178,7 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.catalogString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => - def isAggregateExpression(expr: Expression) = { + def isAggregateExpression(expr: Expression): Boolean = { expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } @@ -376,6 +376,25 @@ trait CheckAnalysis extends PredicateHelper { throw new IllegalStateException( "Internal error: logical hint operator should have been removed during analysis") + case f @ Filter(condition, _) + if PlanHelper.specialExpressionsInUnsupportedOperator(f).nonEmpty => + val invalidExprSqls = PlanHelper.specialExpressionsInUnsupportedOperator(f).map(_.sql) + failAnalysis( + s""" + |Aggregate/Window/Generate expressions are not valid in where clause of the query. + |Expression in where clause: [${condition.sql}] + |Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin) + + case other if PlanHelper.specialExpressionsInUnsupportedOperator(other).nonEmpty => + val invalidExprSqls = + PlanHelper.specialExpressionsInUnsupportedOperator(other).map(_.sql) + failAnalysis( + s""" + |The query operator `${other.nodeName}` contains one or more unsupported + |expression types Aggregate, Window or Generate. + |Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin + ) + case _ => // Analysis successful! 
} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d4eb516534f1..6319d47c9a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -43,38 +43,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // - is still resolved // - only host special expressions in supported operators override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || (plan.resolved && checkSpecialExpressionIntegrity(plan)) - } - - /** - * Check if all operators in this plan hold structural integrity with regards to hosting special - * expressions. - * Returns true when all operators are integral. - */ - private def checkSpecialExpressionIntegrity(plan: LogicalPlan): Boolean = { - plan.find(specialExpressionInUnsupportedOperator).isEmpty - } - - /** - * Check if there's any expression in this query plan operator that is - * - A WindowExpression but the plan is not Window - * - An AggregateExpresion but the plan is not Aggregate or Window - * - A Generator but the plan is not Generate - * Returns true when this operator breaks structural integrity with one of the cases above. 
- */ - private def specialExpressionInUnsupportedOperator(plan: LogicalPlan): Boolean = { - val exprs = plan.expressions - exprs.flatMap { root => - root.find { - case e: WindowExpression - if !plan.isInstanceOf[Window] => true - case e: AggregateExpression - if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => true - case e: Generator - if !plan.isInstanceOf[Generate] => true - case _ => false - } - }.nonEmpty + !Utils.isTesting || (plan.resolved && + plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty) } protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala new file mode 100644 index 000000000000..4a28d879d114 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression + +/** + * [[PlanHelper]] contains utility methods that can be used by Analyzer and Optimizer. + * It can also be a container for methods that are common across multiple rules in Analyzer + * and Optimizer. + */ +object PlanHelper { + /** + * Check if there's any expression in this query plan operator that is + * - A WindowExpression but the plan is not Window + * - An AggregateExpression but the plan is not Aggregate or Window + * - A Generator but the plan is not Generate + * Returns the list of invalid expressions that this operator hosts. This can happen when + * 1. The input query from users contains invalid expressions. + * Example : SELECT * FROM tab WHERE max(c1) > 0 + * 2. Query rewrites inadvertently produce plans that are invalid. + */ + def specialExpressionsInUnsupportedOperator(plan: LogicalPlan): Seq[Expression] = { + val exprs = plan.expressions + val invalidExpressions = exprs.flatMap { root => + root.collect { + case e: WindowExpression + if !plan.isInstanceOf[Window] => e + case e: AggregateExpression + if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => e + case e: Generator + if !plan.isInstanceOf[Generate] => e + } + } + invalidExpressions + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ec71c1c93452..55ce93ead4a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -599,4 +599,12 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan5, "Accessing outer query column is not 
allowed in" :: Nil) } + + test("Error on filter condition containing aggregate expressions") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val plan = Filter('a === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b)) + assertAnalysisError(plan, + "Aggregate/Window/Generate expressions are not valid in where clause of the query" :: Nil) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 7e81ff1aba37..66bc90914e0d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -141,3 +141,16 @@ SELECT every("true"); SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; + +-- Having referencing aggregate expressions is ok. 
+SELECT count(*) FROM test_agg HAVING count(*) > 1L; +SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true; + +-- Aggregate expressions can be referenced through an alias +SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L; + +-- Error when aggregate expressions are in where clause directly +SELECT count(*) FROM test_agg WHERE count(*) > 1L; +SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L; +SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index e22cade93679..109ffa77d621 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -46,9 +46,10 @@ WHERE t1a IN (SELECT min(t2a) SELECT t1a FROM t1 GROUP BY 1 -HAVING EXISTS (SELECT 1 +HAVING EXISTS (SELECT t2a FROM t2 - WHERE t2a < min(t1a + t2a)); + GROUP BY 1 + HAVING t2a < min(t1a + t2a)); -- TC 01.04 -- Invalid due to mixure of outer and local references under an AggegatedExpression diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index daf47c4d0a39..3a5df254f2cd 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 46 +-- Number of queries: 52 -- !query 0 @@ -459,3 +459,65 @@ struct 1L +-- !query 46 schema +struct +-- !query 46 output +10 + + +-- !query 47 +SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true +-- !query 47 schema +struct +-- !query 47 output +1 true +2 true +5 true + + +-- !query 48 +SELECT * FROM (SELECT 
COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L +-- !query 48 schema +struct +-- !query 48 output +10 + + +-- !query 49 +SELECT count(*) FROM test_agg WHERE count(*) > 1L +-- !query 49 schema +struct<> +-- !query 49 output +org.apache.spark.sql.AnalysisException + +Aggregate/Window/Generate expressions are not valid in where clause of the query. +Expression in where clause: [(count(1) > 1L)] +Invalid expressions: [count(1)]; + + +-- !query 50 +SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L +-- !query 50 schema +struct<> +-- !query 50 output +org.apache.spark.sql.AnalysisException + +Aggregate/Window/Generate expressions are not valid in where clause of the query. +Expression in where clause: [((count(1) + 1L) > 1L)] +Invalid expressions: [count(1)]; + + +-- !query 51 +SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1 +-- !query 51 schema +struct<> +-- !query 51 output +org.apache.spark.sql.AnalysisException + +Aggregate/Window/Generate expressions are not valid in where clause of the query. 
+Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))] +Invalid expressions: [count(1), max(test_agg.`k`)]; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index e49978ddb1ce..7b47a6139f60 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -70,9 +70,10 @@ Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2 SELECT t1a FROM t1 GROUP BY 1 -HAVING EXISTS (SELECT 1 +HAVING EXISTS (SELECT t2a FROM t2 - WHERE t2a < min(t1a + t2a)) + GROUP BY 1 + HAVING t2a < min(t1a + t2a)) -- !query 5 schema struct<> -- !query 5 output