Skip to content

Commit b2e65a4

Browse files
Tartarus0zmgodfreyhe
authored andcommitted
[FLINK-21923][table-planner-blink] Fix ClassCastException in SplitAggregateRule when a query contains both sum/count and avg function
This closes #15341
1 parent 033cdea commit b2e65a4

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery
2727
import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
2828
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
2929
import org.apache.flink.table.planner.plan.utils.AggregateUtil.doAllAggSupportSplit
30-
import org.apache.flink.table.planner.plan.utils.{ExpandUtil, WindowUtil}
30+
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ExpandUtil, WindowUtil}
3131

3232
import org.apache.calcite.plan.RelOptRule.{any, operand}
3333
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -138,9 +138,11 @@ class SplitAggregateRule extends RelOptRule(
138138
val windowProps = fmq.getRelWindowProperties(agg.getInput)
139139
val isWindowAgg = WindowUtil.groupingContainsWindowStartEnd(agg.getGroupSet, windowProps)
140140
val isProctimeWindowAgg = isWindowAgg && !windowProps.isRowtime
141+
// TableAggregate is not supported. see also FLINK-21923.
142+
val isTableAgg = AggregateUtil.isTableAggregate(agg.getAggCallList)
141143

142144
agg.partialFinalType == PartialFinalType.NONE && agg.containsDistinctCall() &&
143-
splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg
145+
splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg && !isTableAgg
144146
}
145147

146148
override def onMatch(call: RelOptRuleCall): Unit = {
@@ -280,11 +282,16 @@ class SplitAggregateRule extends RelOptRule(
280282
}
281283

282284
// STEP 2.3: construct partial aggregates
283-
relBuilder.aggregate(
284-
relBuilder.groupKey(fullGroupSet, ImmutableList.of[ImmutableBitSet](fullGroupSet)),
285+
// Create aggregate node directly to avoid ClassCastException,
286+
// Please see FLINK-21923 for more details.
287+
// TODO reuse aggregate function, see FLINK-22412
288+
val partialAggregate = FlinkLogicalAggregate.create(
289+
relBuilder.build(),
290+
fullGroupSet,
291+
ImmutableList.of[ImmutableBitSet](fullGroupSet),
285292
newPartialAggCalls)
286-
relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
287-
.setPartialFinalType(PartialFinalType.PARTIAL)
293+
partialAggregate.setPartialFinalType(PartialFinalType.PARTIAL)
294+
relBuilder.push(partialAggregate)
288295

289296
// STEP 3: construct final aggregates
290297
val finalAggInputOffset = fullGroupSet.cardinality
@@ -306,13 +313,16 @@ class SplitAggregateRule extends RelOptRule(
306313
needMergeFinalAggOutput = true
307314
}
308315
}
309-
relBuilder.aggregate(
310-
relBuilder.groupKey(
311-
SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
312-
SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet))),
316+
// Create aggregate node directly to avoid ClassCastException,
317+
// Please see FLINK-21923 for more details.
318+
// TODO reuse aggregate function, see FLINK-22412
319+
val finalAggregate = FlinkLogicalAggregate.create(
320+
relBuilder.build(),
321+
SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
322+
SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet)),
313323
finalAggCalls)
314-
val finalAggregate = relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
315324
finalAggregate.setPartialFinalType(PartialFinalType.FINAL)
325+
relBuilder.push(finalAggregate)
316326

317327
// STEP 4: convert final aggregation output to the original aggregation output.
318328
// For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of

flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,4 +430,35 @@ FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[$SUM0($2)])
430430
]]>
431431
</Resource>
432432
</TestCase>
433+
<TestCase name="testAggFilterClauseBothWithAvgAndCount">
434+
<Resource name="sql">
435+
<![CDATA[
436+
SELECT
437+
a,
438+
COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
439+
SUM(b) FILTER (WHERE NOT b = 5),
440+
COUNT(b),
441+
AVG(b),
442+
SUM(b)
443+
FROM MyTable
444+
GROUP BY a
445+
]]>
446+
</Resource>
447+
<Resource name="ast">
448+
<![CDATA[
449+
LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1) FILTER $2], EXPR$2=[SUM($1) FILTER $3], EXPR$3=[COUNT($1)], EXPR$4=[AVG($1)], EXPR$5=[SUM($1)])
450+
+- LogicalProject(a=[$0], b=[$1], $f2=[IS TRUE(<>($1, 2))], $f3=[IS TRUE(<>($1, 5))])
451+
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
452+
]]>
453+
</Resource>
454+
<Resource name="optimized rel plan">
455+
<![CDATA[
456+
FlinkLogicalCalc(select=[a, $f1, $f2, $f3, CAST(IF(=($f5, 0:BIGINT), null:INTEGER, /($f4, $f5))) AS $f4, $f6])
457+
+- FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[SUM($3)], agg#2=[$SUM0($4)], agg#3=[$SUM0($5)], agg#4=[$SUM0($6)], agg#5=[SUM($7)])
458+
+- FlinkLogicalAggregate(group=[{0, 4}], agg#0=[COUNT(DISTINCT $1) FILTER $2], agg#1=[SUM($1) FILTER $3], agg#2=[COUNT($1)], agg#3=[$SUM0($1)], agg#4=[COUNT($1)], agg#5=[SUM($1)])
459+
+- FlinkLogicalCalc(select=[a, b, IS TRUE(<>(b, 2)) AS $f2, IS TRUE(<>(b, 5)) AS $f3, MOD(HASH_CODE(b), 1024) AS $f4])
460+
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
461+
]]>
462+
</Resource>
463+
</TestCase>
433464
</Root>

flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,4 +186,23 @@ class SplitAggregateRuleTest extends TableTestBase {
186186
|""".stripMargin
187187
util.verifyRelPlan(sqlQuery)
188188
}
189+
190+
@Test
191+
def testAggFilterClauseBothWithAvgAndCount(): Unit = {
192+
util.tableEnv.getConfig.getConfiguration.setBoolean(
193+
OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_ENABLED, true)
194+
val sqlQuery =
195+
s"""
196+
|SELECT
197+
| a,
198+
| COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
199+
| SUM(b) FILTER (WHERE NOT b = 5),
200+
| COUNT(b),
201+
| AVG(b),
202+
| SUM(b)
203+
|FROM MyTable
204+
|GROUP BY a
205+
|""".stripMargin
206+
util.verifyRelPlan(sqlQuery)
207+
}
189208
}

flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,29 @@ class SplitAggregateITCase(
412412
val expected = List("1,2,1,2,1", "2,4,3,4,3", "3,1,1,null,5", "4,2,2,6,5")
413413
assertEquals(expected.sorted, sink.getRetractResults.sorted)
414414
}
415+
416+
@Test
417+
def testAggFilterClauseBothWithAvgAndCount(): Unit = {
418+
val t1 = tEnv.sqlQuery(
419+
s"""
420+
|SELECT
421+
| a,
422+
| COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
423+
| SUM(b) FILTER (WHERE NOT b = 5),
424+
| COUNT(b),
425+
| SUM(b),
426+
| AVG(b)
427+
|FROM T
428+
|GROUP BY a
429+
""".stripMargin)
430+
431+
val sink = new TestingRetractSink
432+
t1.toRetractStream[Row].addSink(sink)
433+
env.execute()
434+
435+
val expected = List("1,1,3,2,3,1", "2,3,24,8,29,3", "3,1,null,2,10,5", "4,2,6,4,21,5")
436+
assertEquals(expected.sorted, sink.getRetractResults.sorted)
437+
}
415438
}
416439

417440
object SplitAggregateITCase {

0 commit comments

Comments
 (0)