
Commit e76b012

Asher Saban authored and cloud-fan committed
[SPARK-23803][SQL] Support bucket pruning
## What changes were proposed in this pull request?

Support bucket pruning when filtering on a single bucketed column, for the following predicates: EqualTo, EqualNullSafe, In, and And/Or predicates.

## How was this patch tested?

Refactored unit tests to test the above.

Based on gatorsmile's work in e3c75c6.

Author: Asher Saban <[email protected]>
Author: asaban <[email protected]>

Closes #20915 from sabanas/filter-prune-buckets.
1 parent e9efb62 commit e76b012
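As context for the change, here is a minimal usage sketch of the kind of query that can now benefit from bucket pruning. The table name, column names, and bucket count are illustrative, not taken from the patch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("bucket-pruning-sketch").getOrCreate()
import spark.implicits._

// Write a table bucketed into 8 buckets on the single column "j".
spark.range(0, 1000)
  .selectExpr("id AS i", "id % 13 AS j")
  .write
  .bucketBy(8, "j")
  .sortBy("j")
  .saveAsTable("bucketed_table")

// Filters on the bucketing column using EqualTo, EqualNullSafe, In, or
// And/Or combinations of them can now skip buckets that cannot match.
spark.table("bucketed_table").filter($"j" === 5).show()
spark.table("bucketed_table").filter($"j".isin(5, 7)).show()
```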

File tree

5 files changed: +231 −35 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 29 additions & 3 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
 
 trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
@@ -151,6 +152,7 @@ case class RowDataSourceScanExec(
  * @param output Output attributes of the scan, including data attributes and partition attributes.
  * @param requiredSchema Required schema of the underlying relation, excluding partition columns.
  * @param partitionFilters Predicates to use for partition pruning.
+ * @param optionalBucketSet Bucket ids for bucket pruning
  * @param dataFilters Filters on non-partition columns.
  * @param tableIdentifier identifier for the table in the metastore.
  */
@@ -159,6 +161,7 @@ case class FileSourceScanExec(
     output: Seq[Attribute],
     requiredSchema: StructType,
     partitionFilters: Seq[Expression],
+    optionalBucketSet: Option[BitSet],
     dataFilters: Seq[Expression],
     override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec with ColumnarBatchScan {
@@ -286,7 +289,20 @@ case class FileSourceScanExec(
     } getOrElse {
       metadata
     }
-    withOptPartitionCount
+
+    val withSelectedBucketsCount = relation.bucketSpec.map { spec =>
+      val numSelectedBuckets = optionalBucketSet.map { b =>
+        b.cardinality()
+      } getOrElse {
+        spec.numBuckets
+      }
+      withOptPartitionCount + ("SelectedBucketsCount" ->
+        s"$numSelectedBuckets out of ${spec.numBuckets}")
+    } getOrElse {
+      withOptPartitionCount
+    }
+
+    withSelectedBucketsCount
   }
 
   private lazy val inputRDD: RDD[InternalRow] = {
@@ -365,7 +381,7 @@ case class FileSourceScanExec(
       selectedPartitions: Seq[PartitionDirectory],
       fsRelation: HadoopFsRelation): RDD[InternalRow] = {
     logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
-    val bucketed =
+    val filesGroupedToBuckets =
       selectedPartitions.flatMap { p =>
         p.files.map { f =>
           val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
@@ -377,8 +393,17 @@ case class FileSourceScanExec(
           .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
       }
 
+    val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+      val bucketSet = optionalBucketSet.get
+      filesGroupedToBuckets.filter {
+        f => bucketSet.get(f._1)
+      }
+    } else {
+      filesGroupedToBuckets
+    }
+
     val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
-      FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
+      FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
     }
 
     new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
@@ -503,6 +528,7 @@ case class FileSourceScanExec(
       output.map(QueryPlan.normalizeExprId(_, output)),
       requiredSchema,
       QueryPlan.normalizePredicates(partitionFilters, output),
+      optionalBucketSet,
      QueryPlan.normalizePredicates(dataFilters, output),
       None)
   }
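With the SelectedBucketsCount entry added above, the effect of pruning becomes visible in the scan node's metadata. A hedged illustration, reusing the hypothetical bucketed_table from the earlier sketch (the exact plan text depends on the file format and Spark version):

```scala
// If only the bucket containing j = 5 survives pruning, the FileSourceScanExec
// metadata would report something like "SelectedBucketsCount: 1 out of 8".
spark.table("bucketed_table").filter($"j" === 5).explain()
```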

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+
 object BucketingUtils {
   // The file name of bucketed data should have 3 parts:
   //   1. some other information in the head of file name
@@ -35,5 +38,16 @@ object BucketingUtils {
     case other => None
   }
 
+  // Given bucketColumn, numBuckets and value, returns the corresponding bucketId
+  def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
+    val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
+    mutableInternalRow.update(0, value)
+
+    val bucketIdGenerator = UnsafeProjection.create(
+      HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil,
+      bucketColumn :: Nil)
+    bucketIdGenerator(mutableInternalRow).getInt(0)
+  }
+
   def bucketIdToString(id: Int): String = f"_$id%05d"
 }
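For readers unfamiliar with the internals, the sketch below mirrors the logic of getBucketIdFromValue using the same catalyst classes. These are internal APIs, so this is only an illustration of how the bucket id is derived, not a supported entry point; the attribute name, type, and value are made up:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

val bucketColumn = AttributeReference("j", IntegerType)()
val numBuckets = 8

// Wrap the filter value in a one-column mutable internal row.
val row = new SpecificInternalRow(Seq(bucketColumn.dataType))
row.update(0, 5)

// HashPartitioning.partitionIdExpression is the same expression that assigned
// rows to bucket files at write time, so evaluating it on the value yields the
// id of the only bucket file that can contain matching rows.
val bucketIdGenerator = UnsafeProjection.create(
  HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil,
  bucketColumn :: Nil)

val bucketId = bucketIdGenerator(row).getInt(0) // in the range [0, numBuckets)
```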

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 0 additions & 12 deletions
@@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
     case _ => Nil
   }
 
-  // Get the bucket ID based on the bucketing values.
-  // Restriction: Bucket pruning works iff the bucketing column has one and only one column.
-  def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
-    val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
-    mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
-    val bucketIdGeneration = UnsafeProjection.create(
-      HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
-      bucketColumn :: Nil)
-
-    bucketIdGeneration(mutableRow).getInt(0)
-  }
-
   // Based on Public API.
   private def pruneFilterProject(
       relation: LogicalRelation,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 96 additions & 2 deletions
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.util.collection.BitSet
 
 /**
  * A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan
  *       and add it. Proceed to the next file.
  */
 object FileSourceStrategy extends Strategy with Logging {
+
+  // should prune buckets iff num buckets is greater than 1 and there is only one bucket column
+  private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
+    bucketSpec match {
+      case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1
+      case None => false
+    }
+  }
+
+  private def getExpressionBuckets(
+      expr: Expression,
+      bucketColumnName: String,
+      numBuckets: Int): BitSet = {
+
+    def getBucketNumber(attr: Attribute, v: Any): Int = {
+      BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)
+    }
+
+    def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = {
+      val matchedBuckets = new BitSet(numBuckets)
+      iter
+        .map(v => getBucketNumber(attr, v))
+        .foreach(bucketNum => matchedBuckets.set(bucketNum))
+      matchedBuckets
+    }
+
+    def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = {
+      val matchedBuckets = new BitSet(numBuckets)
+      matchedBuckets.set(getBucketNumber(attr, v))
+      matchedBuckets
+    }
+
+    expr match {
+      case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+        getBucketSetFromValue(a, v)
+      case expressions.In(a: Attribute, list)
+        if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+        getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow)))
+      case expressions.InSet(a: Attribute, hset)
+        if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+        getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow)))
+      case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
+        getBucketSetFromValue(a, null)
+      case expressions.And(left, right) =>
+        getExpressionBuckets(left, bucketColumnName, numBuckets) &
+          getExpressionBuckets(right, bucketColumnName, numBuckets)
+      case expressions.Or(left, right) =>
+        getExpressionBuckets(left, bucketColumnName, numBuckets) |
+          getExpressionBuckets(right, bucketColumnName, numBuckets)
+      case _ =>
+        val matchedBuckets = new BitSet(numBuckets)
+        matchedBuckets.setUntil(numBuckets)
+        matchedBuckets
+    }
+  }
+
+  private def genBucketSet(
+      normalizedFilters: Seq[Expression],
+      bucketSpec: BucketSpec): Option[BitSet] = {
+    if (normalizedFilters.isEmpty) {
+      return None
+    }
+
+    val bucketColumnName = bucketSpec.bucketColumnNames.head
+    val numBuckets = bucketSpec.numBuckets
+
+    val normalizedFiltersAndExpr = normalizedFilters
+      .reduce(expressions.And)
+    val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName,
+      numBuckets)
+
+    val numBucketsSelected = matchedBuckets.cardinality()
+
+    logInfo {
+      s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets."
+    }
+
+    // None means all the buckets need to be scanned
+    if (numBucketsSelected == numBuckets) {
+      None
+    } else {
+      Some(matchedBuckets)
+    }
+  }
+
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(projects, filters,
       l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
@@ -82,6 +168,13 @@ object FileSourceStrategy extends Strategy with Logging {
 
       logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
 
+      val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec
+      val bucketSet = if (shouldPruneBuckets(bucketSpec)) {
+        genBucketSet(normalizedFilters, bucketSpec.get)
+      } else {
+        None
+      }
+
       val dataColumns =
         l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
 
@@ -111,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging {
           outputAttributes,
           outputSchema,
           partitionKeyFilters.toSeq,
+          bucketSet,
           dataFilters,
           table.map(_.identifier))
 
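To make the And/Or handling in getExpressionBuckets concrete, here is a small hedged sketch of how the per-predicate bucket BitSets combine. The bucket id below is a placeholder; in the real code it comes from BucketingUtils.getBucketIdFromValue:

```scala
import org.apache.spark.util.collection.BitSet

val numBuckets = 8

// "j = 5": a single matching bucket (placeholder id 3 stands in for the hashed value).
val eqBuckets = new BitSet(numBuckets)
eqBuckets.set(3)

// A predicate on a non-bucket column: no pruning is possible, so every bucket matches.
val allBuckets = new BitSet(numBuckets)
allBuckets.setUntil(numBuckets)

// And -> intersection, Or -> union, exactly as in the pattern match above.
val andBuckets = eqBuckets & allBuckets   // still only bucket 3
val orBuckets = eqBuckets | allBuckets    // every bucket; genBucketSet then returns None

assert(andBuckets.cardinality() == 1)
assert(orBuckets.cardinality() == numBuckets)
```

When the combined set covers all buckets, genBucketSet returns None and the scan proceeds exactly as before this change.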
