diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ff8856708c6d1..681163ca507b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -74,12 +74,14 @@ object SubqueryExpression {
   }
 
   /**
-   * Returns true when an expression contains a subquery that has outer reference(s). The outer
+   * Returns true when an expression contains a subquery that has outer reference(s), except
+   * for [[org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery]]. The outer
    * reference attributes are kept as children of subquery expression by
    * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]]
    */
   def hasCorrelatedSubquery(e: Expression): Boolean = {
     e.find {
+      case _: DynamicPruningSubquery => false
       case s: SubqueryExpression => s.children.nonEmpty
       case _ => false
     }.isDefined
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index c4243da7b9e4b..61ee3019d9344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -58,10 +58,16 @@ trait ConstraintHelper {
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
    * additional constraint of the form `b = 5`.
    */
-  def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
+  def inferAdditionalConstraints(
+      constraints: ExpressionSet,
+      isInferDynamicPruning: Boolean = false): ExpressionSet = {
     var inferredConstraints = ExpressionSet()
     // IsNotNull should be constructed by `constructIsNotNullConstraints`.
-    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
+    val predicates = if (isInferDynamicPruning) {
+      constraints.filterNot(_.isInstanceOf[IsNotNull])
+    } else {
+      constraints.filterNot(e => e.isInstanceOf[IsNotNull] || e.isInstanceOf[DynamicPruning])
+    }
     predicates.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         val candidateConstraints = predicates - eq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index dde5dc2be0556..94e4ad786a574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
 import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes}
-import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
+import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, InferDynamicPruningFilters, PartitionPruning}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
@@ -44,7 +44,11 @@ class SparkOptimizer(
     Batch("PartitionPruning", Once,
       PartitionPruning,
       OptimizeSubqueries) :+
-    Batch("Pushdown Filters from PartitionPruning", fixedPoint,
+    Batch("Pushdown Filters from PartitionPruning before Inferring Filters", fixedPoint,
+      PushDownPredicates) :+
+    Batch("Infer Filters from PartitionPruning", Once,
+      InferDynamicPruningFilters) :+
+    Batch("Pushdown Filters from PartitionPruning after Inferring Filters", fixedPoint,
       PushDownPredicates) :+
     Batch("Cleanup filters that cannot be pushed down", Once,
       CleanupDynamicPruningFilters,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/InferDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/InferDynamicPruningFilters.scala
new file mode 100644
index 0000000000000..b92ab41f0169b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/InferDynamicPruningFilters.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.dynamicpruning
+
+import org.apache.spark.sql.catalyst.expressions.{And, DynamicPruningSubquery, ExpressionSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical.{ConstraintHelper, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.dynamicpruning.PartitionPruning._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Similar to InferFiltersFromConstraints, but this rule only infers DynamicPruning filters.
+ */
+object InferDynamicPruningFilters extends Rule[LogicalPlan]
+  with PredicateHelper with ConstraintHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      inferFilters(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(left, right, joinType, _, _) =>
+      joinType match {
+        // For inner joins, we can infer additional filters for both sides. LeftSemi is a kind
+        // of inner join; it just drops the right side from the final output.
+        case _: InnerLike | LeftSemi =>
+          val allConstraints = inferDynamicPrunings(join)
+          val newLeft = inferNewFilter(left, allConstraints)
+          val newRight = inferNewFilter(right, allConstraints)
+          join.copy(left = newLeft, right = newRight)
+
+        // For right outer joins, we can only infer additional filters for the left side.
+        case RightOuter =>
+          val allConstraints = inferDynamicPrunings(join)
+          val newLeft = inferNewFilter(left, allConstraints)
+          join.copy(left = newLeft)
+
+        // For left outer and left anti joins, we can only infer filters for the right side.
+        case LeftOuter | LeftAnti =>
+          val allConstraints = inferDynamicPrunings(join)
+          val newRight = inferNewFilter(right, allConstraints)
+          join.copy(right = newRight)
+
+        case _ => join
+      }
+  }
+
+  def inferDynamicPrunings(join: Join): ExpressionSet = {
+    val baseConstraints = join.left.constraints.union(join.right.constraints)
+      .union(ExpressionSet(join.condition.map(splitConjunctivePredicates).getOrElse(Nil)))
+    inferAdditionalConstraints(baseConstraints, isInferDynamicPruning = true).filter {
+      case DynamicPruningSubquery(
+          pruningKey, buildQuery, buildKeys, broadcastKeyIndex, _, _) =>
+        getPartitionTableScan(pruningKey, join) match {
+          case Some(partScan) =>
+            pruningHasBenefit(pruningKey, partScan, buildKeys(broadcastKeyIndex), buildQuery)
+          case _ =>
+            false
+        }
+      case _ => false
+    }
+  }
+
+  private def inferNewFilter(plan: LogicalPlan, dynamicPrunings: ExpressionSet): LogicalPlan = {
+    val newPredicates = dynamicPrunings
+      .filter { c =>
+        c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
+      } -- plan.constraints
+    if (newPredicates.isEmpty) {
+      plan
+    } else {
+      Filter(newPredicates.reduce(And), plan)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 182c2aaad581c..324003089b5c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -112,7 +112,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper {
    * using column statistics if they are available, otherwise we use the config value of
   * `spark.sql.optimizer.joinFilterRatio`.
    */
-  private def pruningHasBenefit(
+  private[sql] def pruningHasBenefit(
       partExpr: Expression,
       partPlan: LogicalPlan,
       otherExpr: Expression,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index cd7c4415d6f2b..5f3df05bdf07c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -1388,6 +1388,111 @@ abstract class DynamicPartitionPruningSuiteBase
       checkAnswer(df, Nil)
     }
   }
+
+  test("Infer filters from DPP") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
+      withTable("df1", "df2", "df3", "df4", "df5") {
+        spark.range(1000)
+          .select(col("id"), col("id").as("k"))
+          .write
+          .partitionBy("k")
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("df1")
+
+        spark.range(1000)
+          .select(col("id"), col("id").as("k"))
+          .write
+          .partitionBy("k")
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("df2")
+
+        spark.range(5)
+          .select(col("id"), col("id").as("k"))
+          .write
+          .partitionBy("k")
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("df3")
+
+        spark.range(100)
+          .select(col("id"), col("id").as("k"))
+          .write
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("df4")
+
+        spark.range(1000)
+          .select(col("id"), col("id").as("k"))
+          .write
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("df5")
+
+        Given("Inferred DPP on partition column")
+        Seq(true, false).foreach { infer =>
+          withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> s"$infer") {
+            val df = sql(
+              """
+                |SELECT t1.id,
+                |       df4.k
+                |FROM   (SELECT df2.id,
+                |               df1.k
+                |        FROM   df1
+                |               JOIN df2
+                |                 ON df1.k = df2.k) t1
+                |       JOIN df4
+                |         ON t1.k = df4.k AND df4.id < 2
+                |""".stripMargin)
+            if (infer) {
+              assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 2)
+            } else {
+              assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
+            }
+            checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
+          }
+        }
+
+        Given("Remove inferred DPP on partition column when it has no benefit")
+        withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
+          val df = sql(
+            """
+              |SELECT t1.id,
+              |       df4.k
+              |FROM   (SELECT df3.id,
+              |               df1.k
+              |        FROM   df1
+              |               JOIN df3
+              |                 ON df1.k = df3.k) t1
+              |       JOIN df4
+              |         ON t1.k = df4.k AND df4.id < 2
+              |""".stripMargin)
+          assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
+          checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
+        }
+
+        Given("Remove inferred DPP on non-partition column")
+        withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
+          val df = sql(
+            """
+              |SELECT t1.id,
+              |       df4.k
+              |FROM   (SELECT df5.id,
+              |               df1.k
+              |        FROM   df1
+              |               JOIN df5
+              |                 ON df1.k = df5.k) t1
+              |       JOIN df4
+              |         ON t1.k = df4.k AND df4.id < 2
+              |""".stripMargin)
+
+          assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
+          checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
+        }
+      }
+    }
+  }
 }
 
 class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase {
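
Note for reviewers: the spark-shell session below is a minimal sketch of the intended effect, mirroring the first test case above. It assumes a build that includes this patch; "parquet" stands in for the suite's `tableFormat`, and the plan shape described in the trailing comment is the expected outcome, not captured output.

// Sketch only: table layout mirrors the "Infer filters from DPP" test.
spark.range(1000).selectExpr("id", "id AS k")
  .write.partitionBy("k").format("parquet").mode("overwrite").saveAsTable("df1")
spark.range(1000).selectExpr("id", "id AS k")
  .write.partitionBy("k").format("parquet").mode("overwrite").saveAsTable("df2")
spark.range(100).selectExpr("id", "id AS k")
  .write.format("parquet").mode("overwrite").saveAsTable("df4")

spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
spark.conf.set("spark.sql.constraintPropagation.enabled", "true")
spark.sql("""
  SELECT t1.id, df4.k
  FROM (SELECT df2.id, df1.k FROM df1 JOIN df2 ON df1.k = df2.k) t1
  JOIN df4 ON t1.k = df4.k AND df4.id < 2
""").explain()
// PartitionPruning alone plants one DynamicPruningSubquery on df1.k (via the
// join key t1.k = df4.k). Because df1.k = df2.k is a propagated constraint,
// InferDynamicPruningFilters should derive a second subquery on df2.k, so
// both partitioned scans carry a dynamicpruning expression in their
// PartitionFilters; with constraint propagation disabled, only df1's scan does.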