
Commit b300ed8

dbtsai authored and cloud-fan committed
[SPARK-25556][SPARK-17636][SPARK-31026][SPARK-31060][SQL][TEST-HIVE1.2] Nested Column Predicate Pushdown for Parquet
### What changes were proposed in this pull request?

1. `DataSourceStrategy.scala` is extended to create `org.apache.spark.sql.sources.Filter` from nested expressions.
2. Translation from nested `org.apache.spark.sql.sources.Filter` to `org.apache.parquet.filter2.predicate.FilterPredicate` is implemented to support nested predicate pushdown for Parquet.

### Why are the changes needed?

Better performance for handling nested predicate pushdown.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

New tests are added.

Closes #27728 from dbtsai/SPARK-17636.

Authored-by: DB Tsai <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit cb0db21)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent a7c58b1 commit b300ed8
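To make the change concrete, here is a minimal sketch of a query that benefits. The path and schema are illustrative only, not taken from the patch:

// Minimal sketch (hypothetical path and schema). With
// spark.sql.optimizer.nestedPredicatePushdown.enabled (default: true), a predicate
// on a struct field can be pushed down to the Parquet reader instead of being
// evaluated only in Spark after the scan.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Write a table with a nested column: id INT, name STRUCT<_1: STRING, _2: STRING>.
Seq((1, ("Alice", "Smith")), (2, ("Bob", "Jones")))
  .toDF("id", "name")
  .write.mode("overwrite").parquet("/tmp/nested_pushdown_demo")

// The filter on name._1 can now be translated into a Parquet FilterPredicate.
val df = spark.read.parquet("/tmp/nested_pushdown_demo").filter($"name._1" === "Alice")
df.explain() // PushedFilters should list the nested column in the Parquet scan node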

File tree

14 files changed: +852 −480 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala

Lines changed: 8 additions & 0 deletions

@@ -20,7 +20,9 @@ package org.apache.spark.sql.connector.catalog
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Conversion helpers for working with v2 [[CatalogPlugin]].
@@ -132,4 +134,10 @@ private[sql] object CatalogV2Implicits {
       part
     }
   }
+
+  private lazy val catalystSqlParser = new CatalystSqlParser(SQLConf.get)
+
+  def parseColumnPath(name: String): Seq[String] = {
+    catalystSqlParser.parseMultipartIdentifier(name)
+  }
 }
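For intuition, here is how the new helper parses column paths (REPL-style sketch; results shown in comments):

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath

parseColumnPath("a.b")        // Seq("a", "b")   -- field `b` nested inside column `a`
parseColumnPath("`a.b`")      // Seq("a.b")      -- one top-level column whose name contains a dot
parseColumnPath("a.`b.c`.d")  // Seq("a", "b.c", "d")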

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions

@@ -1857,6 +1857,17 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val NESTED_PREDICATE_PUSHDOWN_ENABLED =
+    buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled")
+      .internal()
+      .doc("When true, Spark tries to push down predicates for nested columns and/or names " +
+        "containing `dots` to data sources. Currently, Parquet implements both optimizations " +
+        "while ORC only supports predicates for names containing `dots`. The other data " +
+        "sources don't support this feature yet.")
+      .version("3.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
       .internal()
@@ -2790,6 +2801,8 @@ class SQLConf extends Serializable with Logging {
 
   def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
 
+  def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED)
+
   def serializerNestedSchemaPruningEnabled: Boolean =
     getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
 
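The flag is marked internal, but it can still be toggled per session for troubleshooting; a short sketch:

// Revert to the old behavior (no nested predicate pushdown) for this session.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "false")

// Inside Spark itself the new accessor reads the same value:
// SQLConf.get.nestedPredicatePushdownEnabled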

sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala

Lines changed: 59 additions & 1 deletion

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources
 
 import org.apache.spark.annotation.{Evolving, Stable}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines all the filters that we can push down to the data sources.
@@ -32,6 +33,10 @@ import org.apache.spark.annotation.{Evolving, Stable}
 sealed abstract class Filter {
   /**
    * List of columns that are referenced by this filter.
+   *
+   * Note that each element in `references` represents a column; `dots` are used as separators
+   * for nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+   *
    * @since 2.1.0
    */
   def references: Array[String]
@@ -40,12 +45,32 @@ sealed abstract class Filter {
     case f: Filter => f.references
     case _ => Array.empty
   }
+
+  /**
+   * List of columns that are referenced by this filter.
+   *
+   * @return each element is a column name as an array of string multi-identifier
+   * @since 3.0.0
+   */
+  def v2references: Array[Array[String]] = {
+    this.references.map(parseColumnPath(_).toArray)
+  }
+
+  /**
+   * If any of the references of this filter contains nested column
+   */
+  private[sql] def containsNestedColumn: Boolean = {
+    this.v2references.exists(_.length > 1)
+  }
 }
 
 /**
- * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * A filter that evaluates to `true` iff the column evaluates to a value
  * equal to `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -58,6 +83,9 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
  * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false`
  * (rather than NULL) if one of the input is NULL and the other is not NULL.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.5.0
  */
 @Stable
@@ -69,6 +97,9 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to a value
  * greater than `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -80,6 +111,9 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to a value
  * greater than or equal to `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -91,6 +125,9 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to a value
  * less than `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
 */
 @Stable
@@ -102,6 +139,9 @@ case class LessThan(attribute: String, value: Any) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to a value
  * less than or equal to `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -112,6 +152,9 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
 /**
  * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -139,6 +182,9 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
 /**
  * A filter that evaluates to `true` iff the attribute evaluates to null.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -149,6 +195,9 @@ case class IsNull(attribute: String) extends Filter {
 /**
  * A filter that evaluates to `true` iff the attribute evaluates to a non-null value.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.0
  */
 @Stable
@@ -190,6 +239,9 @@ case class Not(child: Filter) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to
  * a string that starts with `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.1
  */
 @Stable
@@ -201,6 +253,9 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to
  * a string that ends with `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.1
  */
 @Stable
@@ -212,6 +267,9 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
  * A filter that evaluates to `true` iff the attribute evaluates to
  * a string that contains the string `value`.
  *
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
  * @since 1.3.1
  */
 @Stable
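A sketch of how `references` and the new `v2references` relate on a nested filter (values shown in comments):

import org.apache.spark.sql.sources.{And, EqualTo, GreaterThan}

val f = And(EqualTo("a.b", 1), GreaterThan("`c.d`", 0))

f.references    // Array("a.b", "`c.d`")  -- dotted strings, quoted where a name part contains a dot
f.v2references  // Array(Array("a", "b"), Array("c.d"))  -- parsed multipart identifiers
// f.containsNestedColumn (private[sql]) is true: "a.b" parses into more than one name part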

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 12 additions & 3 deletions

@@ -652,10 +652,19 @@ object DataSourceStrategy {
  */
 object PushableColumn {
   def unapply(e: Expression): Option[String] = {
-    def helper(e: Expression) = e match {
-      case a: Attribute => Some(a.name)
+    val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    def helper(e: Expression): Option[Seq[String]] = e match {
+      case a: Attribute =>
+        if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
+          Some(Seq(a.name))
+        } else {
+          None
+        }
+      case s: GetStructField if nestedPredicatePushdownEnabled =>
+        helper(s.child).map(_ :+ s.childSchema(s.ordinal).name)
       case _ => None
     }
-    helper(e)
+    helper(e).map(_.quoted)
   }
 }
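The `.quoted` call is what keeps nested paths and dotted top-level names distinguishable. A minimal re-implementation of the rule, for intuition only (the real one lives in `CatalogV2Implicits.MultipartIdentifierHelper`):

// Join name parts with dots, backquoting any part that itself contains a dot.
def quoted(parts: Seq[String]): String =
  parts.map(p => if (p.contains(".")) s"`$p`" else p).mkString(".")

quoted(Seq("a", "b"))  // "a.b"   -- GetStructField over column `a`, field `b`
quoted(Seq("a.b"))     // "`a.b`" -- a single Attribute literally named "a.b"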

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 47 additions & 32 deletions

@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.filter2.predicate.SparkFilterApi._
 import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type}
 import org.apache.parquet.schema.OriginalType._
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -49,15 +49,35 @@ class ParquetFilters(
     pushDownInFilterThreshold: Int,
     caseSensitive: Boolean) {
   // A map which contains parquet field name and data type, if predicate push down applies.
-  private val nameToParquetField : Map[String, ParquetField] = {
-    // Here we don't flatten the fields in the nested schema but just look up through
-    // root fields. Currently, accessing to nested fields does not push down filters
-    // and it does not support to create filters for them.
-    val primitiveFields =
-      schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
-        f.getName -> ParquetField(f.getName,
-          ParquetSchemaType(f.getOriginalType,
-            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+  //
+  // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+  // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+  // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
+  private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
+    // Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
+    // `parentFieldNames` is used to keep track of the current nested level when traversing.
+    def getPrimitiveFields(
+        fields: Seq[Type],
+        parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+      fields.flatMap {
+        case p: PrimitiveType =>
+          Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
+            fieldType = ParquetSchemaType(p.getOriginalType,
+              p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
+        // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+        // When g is a `Map`, `g.getOriginalType` is `MAP`.
+        // When g is a `List`, `g.getOriginalType` is `LIST`.
+        case g: GroupType if g.getOriginalType == null =>
+          getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName)
+        // Parquet only supports push-down for primitive types; as a result, Map and List types
+        // are removed.
+        case _ => None
+      }
+    }
+
+    val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field =>
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+      (field.fieldNames.toSeq.quoted, field)
     }
     if (caseSensitive) {
       primitiveFields.toMap
@@ -75,13 +95,13 @@ class ParquetFilters(
   }
 
   /**
-   * Holds a single field information stored in the underlying parquet file.
+   * Holds a single primitive field information stored in the underlying parquet file.
    *
-   * @param fieldName field name in parquet file
+   * @param fieldNames a field name as an array of string multi-identifier in parquet file
    * @param fieldType field type related info in parquet file
    */
-  private case class ParquetField(
-      fieldName: String,
+  private case class ParquetPrimitiveField(
+      fieldNames: Array[String],
       fieldType: ParquetSchemaType)
 
   private case class ParquetSchemaType(
@@ -472,13 +492,8 @@
     case _ => false
   }
 
-  // Parquet does not allow dots in the column name because dots are used as a column path
-  // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
-  // with missing columns. The incorrect results could be got from Parquet when we push down
-  // filters for the column having dots in the names. Thus, we do not push down such filters.
-  // See SPARK-20364.
   private def canMakeFilterOn(name: String, value: Any): Boolean = {
-    nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
   }
 
   /**
@@ -509,38 +524,38 @@
     predicate match {
       case sources.IsNull(name) if canMakeFilterOn(name, null) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
       case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
 
       case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
      case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
      case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
        makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
        makeLt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
      case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
        makeLtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
         makeGt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeGtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.And(lhs, rhs) =>
         // At here, it is not safe to just convert one side and remove the other side
@@ -591,13 +606,13 @@
           && values.distinct.length <= pushDownInFilterThreshold =>
         values.distinct.flatMap { v =>
           makeEq.lift(nameToParquetField(name).fieldType)
-            .map(_(Array(nameToParquetField(name).fieldName), v))
+            .map(_(nameToParquetField(name).fieldNames, v))
         }.reduceLeftOption(FilterApi.or)
 
       case sources.StringStartsWith(name, prefix)
           if pushDownStartWith && canMakeFilterOn(name, prefix) =>
         Option(prefix).map { v =>
-          FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)),
+          FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames),
             new UserDefinedPredicate[Binary] with Serializable {
               private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
               private val size = strToBinary.length
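Putting the pieces together, a hedged sketch of the predicate that the translation can now produce for a nested INT32 field `a.b`, using Parquet's public `FilterApi`; the patch itself goes through the `SparkFilterApi` shim so the column path is built from an explicit array of name parts rather than by splitting a string on dots:

import org.apache.parquet.filter2.predicate.FilterApi

// Spark filter EqualTo("a.b", 5) on schema message { group a { INT32 b } } now
// becomes a predicate over the full path a.b instead of being skipped entirely.
// Note: FilterApi.intColumn splits its argument on dots, which is exactly the
// ambiguity the array-based SparkFilterApi shim avoids for names containing dots.
val pred = FilterApi.eq(FilterApi.intColumn("a.b"), java.lang.Integer.valueOf(5))
println(pred) // eq(a.b, 5)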
