Commit ccba836

avoid pushing down too many predicates in partition pruning
1 parent 004aea8 commit ccba836

4 files changed: +46 −18 lines changed
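Note on the change: CNF conversion of a wide OR of conjunctions multiplies clauses, so pushing every CNF clause down as a data filter can explode the predicate count and the corresponding codegen, while partition pruning only needs the clauses whose references are partition columns. Below is a toy standalone Scala model of the blow-up; it is not Spark's conjunctiveNormalForm, just an illustration of the clause growth that conversion has to cap.

sealed trait Expr
case class Lit(name: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// CNF represented as a list of clauses; each clause is a disjunction of literals.
def toCnf(e: Expr): Seq[Seq[Lit]] = e match {
  case l: Lit    => Seq(Seq(l))
  case And(l, r) => toCnf(l) ++ toCnf(r)                                // CNF(A && B) = CNF(A) ++ CNF(B)
  case Or(l, r)  => for (cl <- toCnf(l); cr <- toCnf(r)) yield cl ++ cr // distribute Or over And
}

// An OR of n two-literal conjunctions expands to 2^n CNF clauses:
val parts: Seq[Expr] = (1 to 10).map(i => And(Lit(s"p0 = $i"), Lit(s"p1 = $i")))
println(toCnf(parts.reduceLeft(Or(_, _))).size) // 1024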

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 2 additions & 9 deletions
@@ -267,20 +267,13 @@ trait PredicateHelper extends Logging {

   /**
    * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
-   * When expanding predicates, this method groups expressions by their references for reducing
-   * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
-   * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
-   * references is subset of partCols, if we combine expressions group by reference when expand
-   * predicate of [[Or]], it won't impact final predicate pruning result since
-   * [[splitConjunctivePredicates]] won't split [[Or]] expression.
    *
    * @param condition condition need to be converted
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
-  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
+  def CNFConversion(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions)
   }

   /**
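For context on the removed method: its post-processing grouped CNF clauses sharing the same attribute references and AND-ed each group back together, while the new CNFConversion keeps every clause untouched. A standalone sketch of the two strategies, with a plain Set[String] standing in for AttributeSet and an illustrative Clause type (assume p0 is a partition column and i is not):

case class Clause(sql: String, references: Set[String])

val clauses = Seq(
  Clause("p0 = 1 OR p0 = 2", Set("p0")),
  Clause("p0 = 1 OR i = 2", Set("p0", "i")),
  Clause("i = 1 OR p0 = 2", Set("i", "p0")),
  Clause("i = 1 OR i = 2", Set("i")))

// Old behavior: clauses with identical reference sets collapse into one conjunction.
val grouped = clauses.groupBy(_.references).values
  .map(_.map(_.sql).mkString("(", ") AND (", ")")).toSeq

// New behavior: the identity transformation keeps all four clauses as-is.
val kept = clauses.map(_.sql)

Either way, only the first clause references partition columns alone, so it is the one partition pruning can use.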

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala

Lines changed: 9 additions & 6 deletions
@@ -53,11 +53,17 @@ private[sql] object PruneFileSourcePartitions
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition(f =>
       f.references.subsetOf(partitionSet)
     )

-    (ExpressionSet(partitionFilters), dataFilters)
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion)
+    val extraPartitionFilters =
+      remainingFilterInCnf.filter(f => f.references.subsetOf(partitionSet))
+
+    (ExpressionSet(partitionFilters ++ extraPartitionFilters), remainingFilters)
   }

   private def rebuildPhysicalOperation(

@@ -88,12 +94,9 @@ private[sql] object PruneFileSourcePartitions
               _,
               _))
           if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
-      val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
-        fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+        fsRelation.sparkSession, logicalRelation, partitionSchema, filters,
         logicalRelation.output)
-
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
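The new flow above first splits off filters that reference only partition columns, then mines the remaining filters for extra partition-only clauses via CNF, while still returning all remaining filters as data filters so nothing extra is pushed down. A minimal generic sketch of that shape (the name and signature are stand-ins, not Spark's API):

// Two-phase extraction: direct partition filters first, then CNF-derived extras.
def extractPartitionFilters[A](
    filters: Seq[A],
    isPartitionOnly: A => Boolean,
    toCnf: A => Seq[A]): (Seq[A], Seq[A]) = {
  val (direct, remaining) = filters.partition(isPartitionOnly)
  val extra = remaining.flatMap(toCnf).filter(isPartitionOnly)
  (direct ++ extra, remaining) // all of `remaining` stays for data filtering
}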

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala

Lines changed: 10 additions & 3 deletions
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions.CNFConversion
 import org.apache.spark.sql.internal.SQLConf

 /**

@@ -54,9 +55,15 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
     val partitionColumnSet = AttributeSet(relation.partitionCols)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition { f =>
       !f.references.isEmpty && f.references.subsetOf(partitionColumnSet)
-    })
+    }
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion)
+    val extraPartitionFilters = remainingFilterInCnf.filter(f =>
+      !f.references.isEmpty && f.references.subsetOf(partitionColumnSet))
+    ExpressionSet(partitionFilters ++ extraPartitionFilters)
   }

   /**

@@ -103,7 +110,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
       if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+      val predicates = CNFConversion(filters.reduceLeft(And))
       val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
       if (partitionKeyFilters.nonEmpty) {
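Worth noting in apply: per the doc comment on the catalyst side, CNFConversion returns Seq.empty when the conversion exceeds its clause threshold, and the rule then falls back to the original filters. A generic sketch of that guard (names illustrative; filters is non-empty because of the pattern guard above):

// Fall back to the unconverted filters when CNF conversion bails out.
def predicatesForPruning[A](filters: Seq[A], and: (A, A) => A, toCnf: A => Seq[A]): Seq[A] = {
  val predicates = toCnf(filters.reduceLeft(and))
  if (predicates.nonEmpty) predicates else filters // CNF hit its size cap: use filters as-is
}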

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala

Lines changed: 25 additions & 0 deletions
@@ -67,6 +67,31 @@ abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with
     }
   }

+  test("SPARK-32284: Avoid pushing down too many predicates in partition pruning") {
+    withTempView("temp") {
+      withTable("t") {
+        sql(
+          s"""
+             |CREATE TABLE t(i INT, p0 INT, p1 INT)
+             |USING $format
+             |PARTITIONED BY (p0, p1)""".stripMargin)
+
+        spark.range(0, 10, 1).selectExpr("id as col")
+          .createOrReplaceTempView("temp")
+
+        for (part <- (0 to 25)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE t PARTITION (p0='$part', p1='$part')
+               |SELECT col FROM temp""".stripMargin)
+        }
+        val scale = 20
+        val predicate = (1 to scale).map(i => s"(p0 = '$i' AND p1 = '$i')").mkString(" OR ")
+        assertPrunedPartitions(s"SELECT * FROM t WHERE $predicate", scale)
+      }
+    }
+  }
+
   protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
     val plan = sql(query).queryExecution.sparkPlan
     assert(getScanExecPartitionSize(plan) == expected)