
Commit 53e0bb3

Davies Liu authored and hvanhovell committed
[SC-3623][BRANCH-2.1] Dynamic partition pruning
## What changes were proposed in this pull request?

This PR ports databricks/runtime#31 over to DB Spark branch-2.1. It adds dynamic partition pruning (see the original PR for more details on the feature). The port was non-trivial because the read path has changed significantly in Spark 2.1. We only support partition pruning for `HadoopFsRelation`. This relation is read exclusively by `FileSourceScanExec`, so dynamic partition pruning is implemented only for this scan operator (in 2.0 we support dynamic partition pruning for both `RowScanExec` and `FileSourceScanExec`).

## How was this patch tested?

Added a test to `SQLQuerySuite`.

Author: Herman van Hovell <[email protected]>
Author: Davies Liu <[email protected]>

Closes apache#131 from hvanhovell/dynamic_partition_pruning.
1 parent 27e7e3d commit 53e0bb3
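The feature only kicks in when a join uses a partition column as a join key and the other side of the join carries a selective filter. The following is an illustrative sketch of such a query, written for spark-shell where `spark` is predefined; the path and column names are made up for the example, not taken from the patch:

```scala
// A fact table partitioned by k, joined against a highly selective subset of itself.
// (Hypothetical path and column names; sketch only.)
val facts = spark.range(100).selectExpr("id", "id % 10 as k")
facts.write.mode("overwrite").partitionBy("k").parquet("/tmp/dpp_facts")

// The filter "id < 2" on the other side is the selective predicate the new
// PartitionPruning rule looks for; it inserts a subquery filter on the partition
// column k, so FileSourceScanExec only reads the matching partition directories.
val joined = spark.read.parquet("/tmp/dpp_facts").join(facts.filter("id < 2"), "k")
joined.explain()  // the parquet scan's partition filters should include the dynamic predicate
```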

File tree

8 files changed (+185, -10 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
     // Filter the plan by applying left semi and left anti joins.
     withSubquery.foldLeft(newFilter) {
+      case (p, PredicateSubquery(_, Seq(e: Expression), _, _)) if !e.isInstanceOf[Predicate] =>
+        // This predicate subquery is inserted by PartitionPruning rule, should not be rewritten.
+        p
       case (p, PredicateSubquery(sub, conditions, _, _)) =>
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         Join(outerPlan, sub, LeftSemi, joinCond)

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 26 additions & 3 deletions
@@ -156,7 +156,30 @@ case class FileSourceScanExec(
     false
   }
 
-  @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters)
+  private def isDynamicPartitionFilter(e: Expression): Boolean =
+    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+  @transient private lazy val selectedPartitions =
+    relation.location.listFiles(partitionFilters.filterNot(isDynamicPartitionFilter))
+
+  // We can only determine the actual partitions at runtime when a dynamic partition filter is
+  // present. This is because such a filter relies on information that is only available at run
+  // time (for instance the keys used in the other side of a join).
+  @transient private lazy val dynamicallySelectedPartitions = {
+    val dynamicPartitionFilters = partitionFilters.filter(isDynamicPartitionFilter)
+    if (dynamicPartitionFilters.nonEmpty) {
+      val predicate = dynamicPartitionFilters.reduce(And)
+      val partitionColumns = relation.partitionSchema
+      val boundPredicate = newPredicate(predicate.transform {
+        case a: AttributeReference =>
+          val index = partitionColumns.indexWhere(a.name == _.name)
+          BoundReference(index, partitionColumns(index).dataType, nullable = true)
+      }, Nil)
+      selectedPartitions.filter(p => boundPredicate.eval(p.values))
+    } else {
+      selectedPartitions
+    }
+  }
 
   override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
     val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) {
@@ -261,9 +284,9 @@ case class FileSourceScanExec(
 
     relation.bucketSpec match {
       case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
-        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
+        createBucketedReadRDD(bucketing, readFile, dynamicallySelectedPartitions, relation)
      case _ =>
-        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
+        createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation)
     }
   }
 
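The key idea in `dynamicallySelectedPartitions` above is that the dynamic filters are bound to the partition schema and evaluated against each partition's values after the static listing has already happened. A self-contained sketch of that runtime step follows, using plain Scala stand-ins (`Partition` here is a hypothetical type playing the role of Spark's partition directories, not the real API):

```scala
// Simplified analogue of dynamicallySelectedPartitions: filter already-listed
// partitions with predicates that only become known at runtime (e.g. join keys).
case class Partition(values: Map[String, Long], files: Seq[String])  // stand-in type

def prunePartitions(
    partitions: Seq[Partition],
    dynamicFilters: Seq[Map[String, Long] => Boolean]): Seq[Partition] = {
  if (dynamicFilters.isEmpty) partitions
  // All dynamic filters must hold, mirroring dynamicPartitionFilters.reduce(And).
  else partitions.filter(p => dynamicFilters.forall(f => f(p.values)))
}

// Usage: keep only partitions whose key k is among the runtime-collected join keys.
val joinKeys = Set(0L, 1L)
val parts = Seq(
  Partition(Map("k" -> 0L), Seq("k=0/part-0.parquet")),
  Partition(Map("k" -> 5L), Seq("k=5/part-0.parquet")))
println(prunePartitions(parts, Seq(p => joinKeys.contains(p("k")))))  // keeps only k=0
```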

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 104 additions & 3 deletions
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, Optimizer, PushDownPredicate, PushPredicateThroughJoin}
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.closure.TranslateClosureOptimizerRule
-import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
 import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
@@ -43,10 +45,109 @@ class SparkOptimizer(
     // Java closure to Catalyst expressions
     Batch("Translate Closure", Once, new TranslateClosureOptimizerRule(conf))) ++
     defaultOptimizers :+
+    Batch("PartitionPruning", Once,
+      PartitionPruning(conf),
+      OptimizeSubqueries) :+
+    Batch("Pushdown pruning subquery", fixedPoint,
+      PushPredicateThroughJoin,
+      PushDownPredicate,
+      CombineFilters) :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
     Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
     Batch("User Provided Optimizers", fixedPoint,
       experimentalMethods.extraOptimizations ++ extraOptimizationRules: _*)
   }
 }
+
+/**
+ * Inserts a predicate for partitioned table when partition column is used as join key.
+ */
+case class PartitionPruning(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Returns whether an attribute is a partition column or not.
+   */
+  private def isPartitioned(a: Expression, plan: LogicalPlan): Boolean = {
+    plan.foreach {
+      case l: LogicalRelation if a.references.subsetOf(l.outputSet) =>
+        l.relation match {
+          case fs: HadoopFsRelation =>
+            val partitionColumns = AttributeSet(
+              l.resolve(fs.partitionSchema, fs.sparkSession.sessionState.analyzer.resolver))
+            if (a.references.subsetOf(partitionColumns)) {
+              return true
+            }
+          case _ =>
+        }
+      case _ =>
+    }
+    false
+  }
+
+  private def insertPredicate(
+      partitionedPlan: LogicalPlan,
+      partitioned: Expression,
+      otherPlan: LogicalPlan,
+      value: Expression): LogicalPlan = {
+    val alias = value match {
+      case a: Attribute => a
+      case o => Alias(o, o.toString)()
+    }
+    Filter(
+      PredicateSubquery(Aggregate(Seq(alias), Seq(alias), otherPlan), Seq(partitioned)),
+      partitionedPlan)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.partitionPruning) {
+      return plan
+    }
+    plan transformUp {
+      case join @ Join(left, right, joinType, Some(condition)) =>
+        var newLeft = left
+        var newRight = right
+        splitConjunctivePredicates(condition).foreach {
+          case e @ EqualTo(a: Expression, b: Expression) =>
+            // they should come from different sides, otherwise should be pushed down
+            val (l, r) = if (a.references.subsetOf(left.outputSet) &&
+                b.references.subsetOf(right.outputSet)) {
+              a -> b
+            } else {
+              b -> a
+            }
+            if (isPartitioned(l, left) && hasHighlySelectivePredicate(right) &&
+                (joinType == Inner || joinType == LeftSemi || joinType == RightOuter) &&
+                r.references.subsetOf(right.outputSet)) {
+              newLeft = insertPredicate(newLeft, l, right, r)
+            } else if (isPartitioned(r, right) && hasHighlySelectivePredicate(left) &&
+                (joinType == Inner || joinType == LeftOuter) &&
+                l.references.subsetOf(left.outputSet)) {
+              newRight = insertPredicate(newRight, r, left, l)
+            }
+          case _ =>
+        }
+        Join(newLeft, newRight, joinType, Some(condition))
+    }
+  }
+
+  /**
+   * Returns whether an expression is highly selective or not.
+   */
+  def isHighlySelective(e: Expression): Boolean = e match {
+    case Not(expr) => isHighlySelective(expr)
+    case And(l, r) => isHighlySelective(l) || isHighlySelective(r)
+    case Or(l, r) => isHighlySelective(l) && isHighlySelective(r)
+    case _: BinaryComparison => true
+    case _: In | _: InSet => true
+    case _: StringPredicate => true
+    case _ => false
+  }
+
+  def hasHighlySelectivePredicate(plan: LogicalPlan): Boolean = {
+    plan.find {
+      case f: Filter => isHighlySelective(f.condition)
+      case _ => false
+    }.isDefined
+  }
+}
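The selectivity heuristic above treats an AND as selective when either side is, and an OR only when both sides are. Here is a self-contained sketch of that logic on a toy expression type (not Catalyst's expressions; `Cmp` and `IsNull` below are hypothetical stand-ins chosen for the example):

```scala
// Toy expression ADT mirroring the shape of isHighlySelective in the patch.
sealed trait Expr
case class Cmp(col: String, op: String, value: Long) extends Expr  // e.g. id < 2
case class AndExpr(l: Expr, r: Expr) extends Expr
case class OrExpr(l: Expr, r: Expr) extends Expr
case class NotExpr(e: Expr) extends Expr
case class IsNull(col: String) extends Expr  // stands in for a non-selective predicate

def selective(e: Expr): Boolean = e match {
  case NotExpr(inner) => selective(inner)
  case AndExpr(l, r)  => selective(l) || selective(r)  // one selective side is enough
  case OrExpr(l, r)   => selective(l) && selective(r)  // both branches must be selective
  case _: Cmp         => true                          // comparisons count as selective
  case _              => false
}

println(selective(AndExpr(Cmp("id", "<", 2), IsNull("x"))))  // true
println(selective(OrExpr(Cmp("id", "<", 2), IsNull("x"))))   // false
```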

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ object FileSourceStrategy extends Strategy with Logging {
       val normalizedFilters = filters.map { e =>
         e transform {
           case a: AttributeReference =>
-            a.withName(l.output.find(_.semanticEquals(a)).get.name)
+            a.withName(l.output.find(_.semanticEquals(a)).getOrElse(a).name)
         }
       }
 
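This one-line change matters because a dynamic partition filter can reference an attribute that is not present in the relation's output; `.get` would then throw, while `.getOrElse(a)` falls back to the original attribute. A plain-Scala illustration of that `Option` behaviour (the `Attr` type is just a stand-in, not Spark code):

```scala
case class Attr(name: String)  // stand-in for AttributeReference
val output = Seq(Attr("id"), Attr("k"))

def normalize(a: Attr): Attr =
  output.find(_.name == a.name).getOrElse(a)  // .get here would throw NoSuchElementException

println(normalize(Attr("k")))      // Attr(k): found in the relation's output
println(normalize(Attr("other")))  // Attr(other): falls back instead of failing
```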

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala

Lines changed: 3 additions & 2 deletions
@@ -52,8 +52,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         logicalRelation.resolve(
           partitionSchema, sparkSession.sessionState.analyzer.resolver)
       val partitionSet = AttributeSet(partitionColumns)
-      val partitionKeyFilters =
-        ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
+      val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
+        f.references.subsetOf(partitionSet) && f.find(_.isInstanceOf[SubqueryExpression]).isEmpty
+      })
 
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)

sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -417,6 +417,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARTITION_PRUNING = SQLConfigBuilder("spark.sql.dynamicPartitionPruning")
+    .internal()
+    .doc("When true, we will generate predicate for partition column when it's used as join key")
+    .booleanConf
+    .createWithDefault(true)
+
   val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage")
     .internal()
     .doc("When true, the whole stage (of multiple operators) will be compiled into single java" +
@@ -790,6 +796,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
 
+  def partitionPruning: Boolean = getConf(PARTITION_PRUNING)
+
   def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
 
   def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
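Since the new flag defaults to true, it only needs to be touched when turning the feature off. A usage sketch, assuming an active `SparkSession` named `spark`:

```scala
// Disable dynamic partition pruning for the current session, then restore the default.
spark.conf.set("spark.sql.dynamicPartitionPruning", "false")
// ... run queries without the PartitionPruning rewrite ...
spark.conf.set("spark.sql.dynamicPartitionPruning", "true")
```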

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 16 additions & 1 deletion
@@ -22,8 +22,9 @@ import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
+import org.apache.spark.sql.catalyst.expressions.PlanExpression
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.{aggregate, FileSourceScanExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2086,6 +2087,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("dynamic partition pruning") {
+    withTempDir { dir =>
+      val df = spark.range(100).selectExpr("id", "id as k")
+      df.write.mode("overwrite").partitionBy("k").parquet(dir.toString)
+      val df2 = spark.read.parquet(dir.toString).join(df.filter("id < 2"), "k")
+      assert(df2.queryExecution.executedPlan.find {
+        case s: FileSourceScanExec =>
+          s.partitionFilters.exists(_.find(_.isInstanceOf[PlanExpression[_]]).isDefined)
+        case o => false
+      }.isDefined, "Parquet scan should have partition predicate")
+      checkAnswer(df2, Row(0, 0, 0) :: Row(1, 1, 1) :: Nil)
+    }
+  }
+
   test("SPARK-14986: Outer lateral view with empty generate expression") {
     checkAnswer(
       sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 24 additions & 0 deletions
@@ -2011,6 +2011,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     }
   }
 
+  test("dynamic partition pruning") {
+    withTable("df1", "df2") {
+      spark.range(100)
+        .select($"id", $"id".as("k"))
+        .write
+        .partitionBy("k")
+        .format("parquet")
+        .mode("overwrite")
+        .saveAsTable("df1")
+
+      spark.range(100)
+        .select($"id", $"id".as("k"))
+        .write
+        .partitionBy("k")
+        .format("parquet")
+        .mode("overwrite")
+        .saveAsTable("df2")
+
+      checkAnswer(
+        sql("select df1.id, df2.k from df1 join df2 on df1.k = df2.k and df2.id < 2"),
+        Row(0, 0) :: Row(1, 1) :: Nil)
+    }
+  }
+
   def testCommandAvailable(command: String): Boolean = {
     val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue())
     attempt.isSuccess && attempt.get == 0