Skip to content

Commit 185aba1

Browse files
author
aokolnychyi
committed
Add test cases and move the check
1 parent 646a96d commit 185aba1

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,9 +1749,6 @@ class Analyzer(
17491749

17501750
private def hasWindowFunction(expr: Expression): Boolean = {
17511751
expr.find {
1752-
case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
1753-
failAnalysis("It is not allowed to use a window function inside an aggregate function. " +
1754-
"Please use the inner window function in a sub-query.")
17551752
case window: WindowExpression => true
17561753
case _ => false
17571754
}.isDefined
@@ -1833,6 +1830,10 @@ class Analyzer(
18331830
seenWindowAggregates += newAgg
18341831
WindowExpression(newAgg, spec)
18351832

1833+
case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
1834+
failAnalysis("It is not allowed to use a window function inside an aggregate " +
1835+
"function. Please use the inner window function in a sub-query.")
1836+
18361837
// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
18371838
// we need to extract SUM(x).
18381839
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,14 +699,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
699699
checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
700700
checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
701701
checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
702-
checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('b)))))
702+
checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a)))))
703+
checkWindowError(
704+
testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3))
705+
checkAnswer(
706+
testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3),
707+
Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil)
703708

704709
checkWindowError(
705-
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
710+
sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2"))
706711
checkWindowError(
707-
sql("SELECT MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
712+
sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
708713
checkWindowError(
709-
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
714+
sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
715+
checkWindowError(
716+
sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a"))
717+
checkWindowError(
718+
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
710719
checkAnswer(
711720
sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
712721
Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)

0 commit comments

Comments
 (0)