diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index f82a5a3041b1c..0e016e19a628b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -19,13 +19,17 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries
 import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
-class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession {
+class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession
+  with AdaptiveSparkPlanHelper {
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -201,9 +205,16 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4")
     sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5")
     sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5")
+
+    // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan and would make testing
+    // complicated.
+    conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName)
   }
 
   protected override def afterAll(): Unit = try {
+    conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.defaultValueString)
+
     sql("DROP TABLE IF EXISTS bf1")
     sql("DROP TABLE IF EXISTS bf2")
     sql("DROP TABLE IF EXISTS bf3")
@@ -264,24 +275,28 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     }
   }
 
-  // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan, but the subqueries will
-  // be reused in the physical plan.
-  def getNumBloomFilters(plan: LogicalPlan, scalarSubqueryCTEMultiplicator: Int = 1): Integer = {
-    val numBloomFilterAggs = plan.collectWithSubqueries {
-      case Aggregate(_, aggregateExpressions, _) =>
-        aggregateExpressions.collect {
-          case Alias(AggregateExpression(bfAgg: BloomFilterAggregate, _, _, _, _), _) =>
-            assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
-            assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
-            1
+  def getNumBloomFilters(plan: LogicalPlan): Integer = {
+    val numBloomFilterAggs = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+        => subquery.plan.collect {
+          case Aggregate(_, aggregateExpressions, _) =>
+            aggregateExpressions.map {
+              case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _),
+                _) =>
+                assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
+                assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
+                1
+            }.sum
         }.sum
+      }.sum
     }.sum
     val numMightContains = plan.collect {
       case Filter(condition, _) => condition.collect {
         case BloomFilterMightContain(_, _) => 1
       }.sum
     }.sum
-    assert(numBloomFilterAggs == numMightContains * scalarSubqueryCTEMultiplicator)
+    assert(numBloomFilterAggs == numMightContains)
     numMightContains
   }
 
@@ -385,7 +400,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
         planEnabled = sql(query).queryExecution.optimizedPlan
         checkAnswer(sql(query), expectedAnswer)
       }
-      assert(getNumBloomFilters(planEnabled, 2) == getNumBloomFilters(planDisabled) + 2)
+      assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
     }
   }
 
@@ -413,10 +428,10 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
         checkAnswer(sql(query), expectedAnswer)
       }
       if (numFilterThreshold < 3) {
-        assert(getNumBloomFilters(planEnabled, numFilterThreshold) ==
-          getNumBloomFilters(planDisabled) + numFilterThreshold)
+        assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) +
+          numFilterThreshold)
       } else {
-        assert(getNumBloomFilters(planEnabled, 2) == getNumBloomFilters(planDisabled) + 2)
+        assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
       }
     }
   }
@@ -561,4 +576,30 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
         """.stripMargin)
     }
   }
+
+  test("Merge runtime bloom filters") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000",
+      SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+      // Re-enable `MergeScalarSubqueries`
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "",
+      SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+      val df = sql(query)
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+
+      val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+      val reusedSubqueryIds = collectWithSubqueries(plan) {
+        case rs: ReusedSubqueryExec => rs.child.id
+      }
+
+      assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan")
+      assert(reusedSubqueryIds.size == 1,
+        "Missing or unexpected reused ReusedSubqueryExec in the plan")
+    }
+  }
 }