Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ trait CheckAnalysis extends PredicateHelper {
s"of type ${condition.dataType.catalogString} is not a boolean.")

case Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression) = {
def isAggregateExpression(expr: Expression): Boolean = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}

Expand Down Expand Up @@ -376,6 +376,25 @@ trait CheckAnalysis extends PredicateHelper {
throw new IllegalStateException(
"Internal error: logical hint operator should have been removed during analysis")

case f @ Filter(condition, _)
if PlanHelper.specialExpressionsInUnsupportedOperator(f).nonEmpty =>
val invalidExprSqls = PlanHelper.specialExpressionsInUnsupportedOperator(f).map(_.sql)
failAnalysis(
s"""
|Aggregate/Window/Generate expressions are not valid in where clause of the query.
|Expression in where clause: [${condition.sql}]
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin)

case other if PlanHelper.specialExpressionsInUnsupportedOperator(other).nonEmpty =>
val invalidExprSqls =
PlanHelper.specialExpressionsInUnsupportedOperator(other).map(_.sql)
failAnalysis(
s"""
|The query operator `${other.nodeName}` contains one or more unsupported
|expression types Aggregate, Window or Generate.
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin
)

case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// - is still resolved
// - only host special expressions in supported operators
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || (plan.resolved && checkSpecialExpressionIntegrity(plan))
}

/**
* Check if all operators in this plan hold structural integrity with regards to hosting special
* expressions.
* Returns true when all operators are integral.
*/
private def checkSpecialExpressionIntegrity(plan: LogicalPlan): Boolean = {
plan.find(specialExpressionInUnsupportedOperator).isEmpty
}

/**
* Check if there's any expression in this query plan operator that is
* - A WindowExpression but the plan is not Window
* - An AggregateExpresion but the plan is not Aggregate or Window
* - A Generator but the plan is not Generate
* Returns true when this operator breaks structural integrity with one of the cases above.
*/
private def specialExpressionInUnsupportedOperator(plan: LogicalPlan): Boolean = {
val exprs = plan.expressions
exprs.flatMap { root =>
root.find {
case e: WindowExpression
if !plan.isInstanceOf[Window] => true
case e: AggregateExpression
if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => true
case e: Generator
if !plan.isInstanceOf[Generate] => true
case _ => false
}
}.nonEmpty
!Utils.isTesting || (plan.resolved &&
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
}

protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

/**
* [[PlanHelper]] contains utility methods that can be used by Analyzer and Optimizer.
* It can also be container of methods that are common across multiple rules in Analyzer
* and Optimizer.
*/
object PlanHelper {
/**
* Check if there's any expression in this query plan operator that is
* - A WindowExpression but the plan is not Window
* - An AggregateExpresion but the plan is not Aggregate or Window
* - A Generator but the plan is not Generate
* Returns the list of invalid expressions that this operator hosts. This can happen when
* 1. The input query from users contain invalid expressions.
* Example : SELECT * FROM tab WHERE max(c1) > 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW how about SELECT max(c1) FROM tab WHERE max(c1) > 0? Does it work?

Copy link
Contributor Author

@dilipbiswal dilipbiswal Mar 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan No, this case does not work. But this should work and i have a test case for this.

SELECT max FROM (SELECT max(v1) AS max FROM tab) WHERE max > 1

* 2. Query rewrites inadvertently produce plans that are invalid.
*/
def specialExpressionsInUnsupportedOperator(plan: LogicalPlan): Seq[Expression] = {
val exprs = plan.expressions
val invalidExpressions = exprs.flatMap { root =>
root.collect {
case e: WindowExpression
if !plan.isInstanceOf[Window] => e
case e: AggregateExpression
if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => e
case e: Generator
if !plan.isInstanceOf[Generate] => e
}
}
invalidExpressions
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -599,4 +599,12 @@ class AnalysisErrorSuite extends AnalysisTest {
assertAnalysisError(plan5,
"Accessing outer query column is not allowed in" :: Nil)
}

test("Error on filter condition containing aggregate expressions") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val plan = Filter('a === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b))
assertAnalysisError(plan,
"Aggregate/Window/Generate expressions are not valid in where clause of the query" :: Nil)
}
}
13 changes: 13 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,16 @@ SELECT every("true");
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;

-- Having referencing aggregate expressions is ok.
SELECT count(*) FROM test_agg HAVING count(*) > 1L;
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true;

-- Aggrgate expressions can be referenced through an alias
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L;

-- Error when aggregate expressions are in where clause directly
SELECT count(*) FROM test_agg WHERE count(*) > 1L;
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;

Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ WHERE t1a IN (SELECT min(t2a)
SELECT t1a
FROM t1
GROUP BY 1
HAVING EXISTS (SELECT 1
HAVING EXISTS (SELECT t2a
FROM t2
WHERE t2a < min(t1a + t2a));
GROUP BY 1
HAVING t2a < min(t1a + t2a));

-- TC 01.04
-- Invalid due to mixure of outer and local references under an AggegatedExpression
Expand Down
64 changes: 63 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 46
-- Number of queries: 52


-- !query 0
Expand Down Expand Up @@ -459,3 +459,65 @@ struct<k:int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RA
5 NULL NULL
5 false false
5 true true


-- !query 46
SELECT count(*) FROM test_agg HAVING count(*) > 1L
-- !query 46 schema
struct<count(1):bigint>
-- !query 46 output
10


-- !query 47
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true
-- !query 47 schema
struct<k:int,max(v):boolean>
-- !query 47 output
1 true
2 true
5 true


-- !query 48
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L
-- !query 48 schema
struct<cnt:bigint>
-- !query 48 output
10


-- !query 49
SELECT count(*) FROM test_agg WHERE count(*) > 1L
-- !query 49 schema
struct<>
-- !query 49 output
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [(count(1) > 1L)]
Invalid expressions: [count(1)];


-- !query 50
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L
-- !query 50 schema
struct<>
-- !query 50 output
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [((count(1) + 1L) > 1L)]
Invalid expressions: [count(1)];


-- !query 51
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1
-- !query 51 schema
struct<>
-- !query 51 output
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
Invalid expressions: [count(1), max(test_agg.`k`)];
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2
SELECT t1a
FROM t1
GROUP BY 1
HAVING EXISTS (SELECT 1
HAVING EXISTS (SELECT t2a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan because, the child subquery plan can not be resolved due to missing attributes.

+- Project [1#81]
:     +- !Filter (t2a#9 < min((outer(t1a#6) + t2a#9))#86)
:        +- Aggregate [1], [1 AS 1#81, min((outer(t1a#6) + t2a#9)) AS min((outer(t1a#6) + t2a#9))#86]
:           +- SubqueryAlias `t2`
:              +- Project [t2a#9, t2b#10, t2c#11]
:                 +- SubqueryAlias `t2`
:                    +- LocalRelation [t2a#9, t2b#10, t2c#11]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we do expect analysis exception for this query, don't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Yeah.. this particular test is in subquery suite and checking for a specific analysis expression that disallows a combination of local and outer references inside an aggregate expression. However, this test had the aggregate expression in the Filter clause which we are disallowing now. So i changed it to use "having" instead of "filter". But when i did that i hit a different analysis exception due to missing attributes (this particular test in question is trying to test a different analysis exception). So i fixed the projection here to include the missing attribute so we get a desired subquery related exception.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

FROM t2
WHERE t2a < min(t1a + t2a))
GROUP BY 1
HAVING t2a < min(t1a + t2a))
-- !query 5 schema
struct<>
-- !query 5 output
Expand Down