Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan}
import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.{IntegerType, StructType}

class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession {
class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession
with AdaptiveSparkPlanHelper {

protected override def beforeAll(): Unit = {
super.beforeAll()
Expand Down Expand Up @@ -201,9 +205,16 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4")
sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5")
sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5")

// `MergeScalarSubqueries` can duplicate subqueries in the optimized plan, which would
// complicate testing, so the rule is excluded for this suite.
conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName)
}

protected override def afterAll(): Unit = try {
conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
SQLConf.OPTIMIZER_EXCLUDED_RULES.defaultValueString)

sql("DROP TABLE IF EXISTS bf1")
sql("DROP TABLE IF EXISTS bf2")
sql("DROP TABLE IF EXISTS bf3")
Expand Down Expand Up @@ -264,24 +275,28 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
}
}

// `MergeScalarSubqueries` can duplicate subqueries in the optimized plan, but the subqueries will
// be reused in the physical plan.
def getNumBloomFilters(plan: LogicalPlan, scalarSubqueryCTEMultiplicator: Int = 1): Integer = {
val numBloomFilterAggs = plan.collectWithSubqueries {
case Aggregate(_, aggregateExpressions, _) =>
aggregateExpressions.collect {
case Alias(AggregateExpression(bfAgg: BloomFilterAggregate, _, _, _, _), _) =>
assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
1
def getNumBloomFilters(plan: LogicalPlan): Integer = {
val numBloomFilterAggs = plan.collect {
case Filter(condition, _) => condition.collect {
case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the previous code more as it precisely matches Filter and ScalarSubquery. +1 to this change.

My major concern is `scalarSubqueryCTEMultiplicator`. It makes the tests really hard to write, because we have to work out manually whether scalar subquery merging applies to the query under test.

Can we turn off this optimization completely in this test suite? We can add one more test, which explicitly enables the optimizer rule, to verify the case where scalar subquery merging benefits the bloom filter join.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, let me check...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7ac7981 removes scalarSubqueryCTEMultiplicator, 9b1347e adds a new test

=> subquery.plan.collect {
case Aggregate(_, aggregateExpressions, _) =>
aggregateExpressions.map {
case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _),
_) =>
assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
1
}.sum
}.sum
}.sum
}.sum
val numMightContains = plan.collect {
case Filter(condition, _) => condition.collect {
case BloomFilterMightContain(_, _) => 1
}.sum
}.sum
assert(numBloomFilterAggs == numMightContains * scalarSubqueryCTEMultiplicator)
assert(numBloomFilterAggs == numMightContains)
numMightContains
}

Expand Down Expand Up @@ -385,7 +400,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
planEnabled = sql(query).queryExecution.optimizedPlan
checkAnswer(sql(query), expectedAnswer)
}
assert(getNumBloomFilters(planEnabled, 2) == getNumBloomFilters(planDisabled) + 2)
assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
}
}

Expand Down Expand Up @@ -413,10 +428,10 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
checkAnswer(sql(query), expectedAnswer)
}
if (numFilterThreshold < 3) {
assert(getNumBloomFilters(planEnabled, numFilterThreshold) ==
getNumBloomFilters(planDisabled) + numFilterThreshold)
assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)
+ numFilterThreshold)
} else {
assert(getNumBloomFilters(planEnabled, 2) == getNumBloomFilters(planDisabled) + 2)
assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
}
}
}
Expand Down Expand Up @@ -561,4 +576,30 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
""".stripMargin)
}
}

// Checks that, with the optimizer rule exclusion list cleared for this test, the executed plan
// of a two-key join between bf1 and bf2 contains exactly one `SubqueryExec` and exactly one
// `ReusedSubqueryExec`.
test("Merge runtime bloom filters") {
  withSQLConf(
    SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000",
    SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
    SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
    // Clearing the exclusion list re-enables `MergeScalarSubqueries` for this test only.
    SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "",
    SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {

    val joined = sql(
      "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
        "bf1.b1 = bf2.b2 where bf2.a2 = 62")
    // Execute the query first; the executed plan is inspected only afterwards.
    joined.collect()
    val executed = joined.queryExecution.executedPlan

    // Ids of the subqueries that appear directly in the plan.
    val executedSubqueryIds = collectWithSubqueries(executed) {
      case subquery: SubqueryExec => subquery.id
    }
    // Ids of the subqueries that reuse nodes point back to.
    val reusedTargetIds = collectWithSubqueries(executed) {
      case reuse: ReusedSubqueryExec => reuse.child.id
    }

    assert(executedSubqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan")
    assert(reusedTargetIds.size == 1,
      "Missing or unexpected reused ReusedSubqueryExec in the plan")
  }
}
}