@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.filter2.predicate.SparkFilterApi._
 import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type}
 import org.apache.parquet.schema.OriginalType._
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -49,15 +49,35 @@ class ParquetFilters(
     pushDownInFilterThreshold: Int,
     caseSensitive: Boolean) {
   // A map which contains parquet field name and data type, if predicate push down applies.
-  private val nameToParquetField: Map[String, ParquetField] = {
-    // Here we don't flatten the fields in the nested schema but just look up through
-    // root fields. Currently, accessing to nested fields does not push down filters
-    // and it does not support to create filters for them.
-    val primitiveFields =
-      schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
-        f.getName -> ParquetField(f.getName,
-          ParquetSchemaType(f.getOriginalType,
-            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+  //
+  // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+  // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+  // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
+  private val nameToParquetField: Map[String, ParquetPrimitiveField] = {
+    // Recursively traverse the parquet schema to get primitive fields that can be pushed down.
+    // `parentFieldNames` is used to keep track of the current nested level when traversing.
+    def getPrimitiveFields(
+        fields: Seq[Type],
+        parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+      fields.flatMap {
+        case p: PrimitiveType =>
+          Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
+            fieldType = ParquetSchemaType(p.getOriginalType,
+              p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
+        // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+        // When g is a `Map`, `g.getOriginalType` is `MAP`.
+        // When g is a `List`, `g.getOriginalType` is `LIST`.
+        case g: GroupType if g.getOriginalType == null =>
+          getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName)
+        // Parquet only supports push-down for primitive types; as a result, Map and List types
+        // are removed.
+        case _ => None
+      }
+    }
+
+    val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field =>
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+      (field.fieldNames.toSeq.quoted, field)
     }
     if (caseSensitive) {
       primitiveFields.toMap
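To make the new traversal concrete, here is a minimal, self-contained sketch of what it produces. `Field`, `Leaf`, `Group`, `quote`, and `keys` are hypothetical stand-ins for Parquet's `Type`/`PrimitiveType`/`GroupType` and for `MultipartIdentifierHelper.quoted`; the backtick-quoting rule shown is an assumption about how `quoted` escapes name parts containing dots.

```scala
// Toy model of the schema traversal above (not Spark's API).
sealed trait Field { def name: String }
case class Leaf(name: String) extends Field
case class Group(name: String, children: Seq[Field]) extends Field

// Assumed quoting rule: backtick-quote a name part that contains a dot (or a
// backtick, which gets doubled), so it cannot be confused with a nesting separator.
def quote(part: String): String =
  if (part.contains(".") || part.contains("`")) s"`${part.replace("`", "``")}`" else part

// Mirrors getPrimitiveFields: recurse into groups, emit one dotted key per leaf.
def keys(fields: Seq[Field], parents: Seq[String] = Nil): Seq[String] =
  fields.flatMap {
    case Leaf(n) => Seq((parents :+ n).map(quote).mkString("."))
    case Group(n, children) => keys(children, parents :+ n)
  }

// keys(Seq(Group("a", Seq(Leaf("b"))), Leaf("c.d")))
// => Seq("a.b", "`c.d`")
```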
@@ -75,13 +95,13 @@ class ParquetFilters(
   }
 
   /**
-   * Holds a single field information stored in the underlying parquet file.
+   * Holds information for a single primitive field stored in the underlying parquet file.
    *
-   * @param fieldName field name in parquet file
+   * @param fieldNames the field name as a multi-part identifier (array of strings) in the parquet file
    * @param fieldType field type related info in parquet file
    */
-  private case class ParquetField(
-      fieldName: String,
+  private case class ParquetPrimitiveField(
+      fieldNames: Array[String],
       fieldType: ParquetSchemaType)
 
   private case class ParquetSchemaType(
@@ -472,13 +492,8 @@ class ParquetFilters(
       case _ => false
   }
 
-  // Parquet does not allow dots in the column name because dots are used as a column path
-  // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
-  // with missing columns. The incorrect results could be got from Parquet when we push down
-  // filters for the column having dots in the names. Thus, we do not push down such filters.
-  // See SPARK-20364.
   private def canMakeFilterOn(name: String, value: Any): Boolean = {
-    nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
   }
 
   /**
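With keys built that way, the old blanket `!name.contains(".")` guard becomes unnecessary: a bare dot in a key now always means nesting, while a literal dot arrives backtick-quoted. Continuing the toy sketch above (still hypothetical names):

```scala
// The two meanings of a dot stay distinguishable in the map keys:
val schema = Seq(Group("a", Seq(Leaf("b"))), Leaf("a.b"))
keys(schema)  // => Seq("a.b", "`a.b`")
              //    "a.b"   -> nested field b inside struct a
              //    "`a.b`" -> top-level column literally named "a.b"
```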
@@ -509,38 +524,38 @@ class ParquetFilters(
     predicate match {
       case sources.IsNull(name) if canMakeFilterOn(name, null) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
       case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
 
       case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
         makeLt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeLtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
         makeGt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeGtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.And(lhs, rhs) =>
         // At here, it is not safe to just convert one side and remove the other side
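For orientation, a hedged sketch of what one of these branches now builds for a nested column. Only `binaryColumn(Array(...))` is visible in this diff, so the `intColumn` variant of `SparkFilterApi` taking a `String` array is an assumption by analogy:

```scala
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.filter2.predicate.SparkFilterApi._

// Hypothetical: sources.EqualTo("a.b", 5) against schema { a: { b: int } }.
// nameToParquetField("a.b").fieldNames would be Array("a", "b"), and the
// makeEq branch would produce roughly this Parquet predicate for the nested path:
val eqOnNested = FilterApi.eq(
  intColumn(Array("a", "b")),       // assumed SparkFilterApi variant taking String[]
  java.lang.Integer.valueOf(5))     // Parquet filters compare boxed values
```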
@@ -591,13 +606,13 @@ class ParquetFilters(
           && values.distinct.length <= pushDownInFilterThreshold =>
         values.distinct.flatMap { v =>
           makeEq.lift(nameToParquetField(name).fieldType)
-            .map(_(Array(nameToParquetField(name).fieldName), v))
+            .map(_(nameToParquetField(name).fieldNames, v))
         }.reduceLeftOption(FilterApi.or)
 
       case sources.StringStartsWith(name, prefix)
           if pushDownStartWith && canMakeFilterOn(name, prefix) =>
         Option(prefix).map { v =>
-          FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)),
+          FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames),
             new UserDefinedPredicate[Binary] with Serializable {
               private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
               private val size = strToBinary.length
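The hunk is cut off inside the `UserDefinedPredicate` backing `StringStartsWith`. As a sketch of the pruning idea such a predicate implements (illustrative only, not this commit's code), a row group can be skipped by comparing the leading `size` bytes of its min/max statistics against the prefix:

```scala
import java.nio.charset.StandardCharsets
import org.apache.parquet.filter2.predicate.{Statistics, UserDefinedPredicate}
import org.apache.parquet.io.api.Binary
import org.apache.parquet.schema.PrimitiveComparator

// Illustrative prefix-pruning predicate (a sketch, not Spark's implementation).
class StartsWithSketch(prefix: String) extends UserDefinedPredicate[Binary] with Serializable {
  private val strToBinary = Binary.fromReusedByteArray(prefix.getBytes(StandardCharsets.UTF_8))
  private val size = strToBinary.length

  // Keep a value if its first `size` bytes equal the prefix (UTF-8 is prefix-safe).
  override def keep(value: Binary): Boolean = {
    val bytes = value.getBytes
    bytes.length >= size && Binary.fromReusedByteArray(bytes, 0, size) == strToBinary
  }

  // Drop a row group when its max sorts entirely below the prefix or its min
  // sorts entirely above it, comparing only the leading `size` bytes of each.
  override def canDrop(statistics: Statistics[Binary]): Boolean = {
    val cmp = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
    def head(b: Binary): Binary = b.slice(0, math.min(size, b.length))
    cmp.compare(head(statistics.getMax), strToBinary) < 0 ||
      cmp.compare(head(statistics.getMin), strToBinary) > 0
  }

  // Min/max statistics cannot prove NOT(startsWith) for a whole row group, so never drop.
  override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false
}
```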