Skip to content

Commit 5902afe

Browse files
committed
[SPARK-25207][SQL] Case-insensitive field resolution for filter pushdown when reading Parquet
1 parent 4fb96e5 commit 5902afe

File tree

3 files changed

+121
-25
lines changed

3 files changed

+121
-25
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class ParquetFileFormat
377377
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
378378
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
379379
// is used here.
380-
.flatMap(parquetFilters.createFilter(parquetSchema, _))
380+
.flatMap(parquetFilters.createFilter(parquetSchema, _, sqlConf.caseSensitiveAnalysis))
381381
.reduceOption(FilterApi.and)
382382
} else {
383383
None

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
2020
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
2121
import java.math.{BigDecimal => JBigDecimal}
2222
import java.sql.{Date, Timestamp}
23+
import java.util.Locale
2324

2425
import scala.collection.JavaConverters.asScalaBufferConverter
2526

@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
3132
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
3233
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
3334

34-
import org.apache.spark.sql.catalyst.util.DateTimeUtils
35+
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
3536
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
3637
import org.apache.spark.sql.sources
3738
import org.apache.spark.unsafe.types.UTF8String
@@ -350,25 +351,46 @@ private[parquet] class ParquetFilters(
350351
}
351352

352353
/**
353-
* Returns a map from name of the column to the data type, if predicate push down applies.
354+
* Returns a name map and a type map according to the case-sensitivity mode, if predicate
355+
* push down applies.
354356
*/
355-
private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match {
356-
case m: MessageType =>
357-
// Here we don't flatten the fields in the nested schema but just look up through
358-
// root fields. Currently, accessing nested fields does not push down filters,
359-
// and creating filters for them is not supported.
360-
m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
357+
private def getFieldMaps(dataType: MessageType, caseSensitive: Boolean)
358+
: (Map[String, String], Map[String, ParquetSchemaType]) = {
359+
// Here we don't flatten the fields in the nested schema but just look up through
360+
// root fields. Currently, accessing nested fields does not push down filters,
361+
// and creating filters for them is not supported.
362+
val primitiveFields = dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType())
363+
if (caseSensitive) {
364+
val nameMap = primitiveFields.map { f =>
365+
f.getName -> f.getName
366+
}.toMap
367+
val typeMap = primitiveFields.map { f =>
361368
f.getName -> ParquetSchemaType(
362369
f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
363370
}.toMap
364-
case _ => Map.empty[String, ParquetSchemaType]
371+
(nameMap, typeMap)
372+
} else {
373+
// Don't consider ambiguity here, i.e. when more than one field matches in case-insensitive
374+
// mode; just skip pushdown for these fields, as they will trigger an exception when read.
375+
// See: SPARK-25132.
376+
val dedupFields = primitiveFields.map { f =>
377+
f.getName -> ParquetSchemaType(
378+
f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
379+
}.groupBy(_._1.toLowerCase(Locale.ROOT)).filter(_._2.size == 1).mapValues(_.head)
380+
val nameMap = CaseInsensitiveMap(dedupFields.mapValues(_._1))
381+
val typeMap = CaseInsensitiveMap(dedupFields.mapValues(_._2))
382+
(nameMap, typeMap)
383+
}
365384
}
366385

367386
/**
368387
* Converts data sources filters to Parquet filter predicates.
369388
*/
370-
def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
371-
val nameToType = getFieldMap(schema)
389+
def createFilter(
390+
schema: MessageType,
391+
predicate: sources.Filter,
392+
caseSensitive: Boolean = true): Option[FilterPredicate] = {
393+
val (nameMap, typeMap) = getFieldMaps(schema, caseSensitive)
372394

373395
// Decimal type must make sure that filter value's scale matched the file.
374396
// If doesn't matched, which would cause data corruption.
@@ -381,7 +403,7 @@ private[parquet] class ParquetFilters(
381403
// Parquet's type in the given file should be matched to the value's type
382404
// in the pushed filter in order to push down the filter to Parquet.
383405
def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
384-
value == null || (nameToType(name) match {
406+
value == null || (typeMap(name) match {
385407
case ParquetBooleanType => value.isInstanceOf[JBoolean]
386408
case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
387409
case ParquetLongType => value.isInstanceOf[JLong]
@@ -408,7 +430,7 @@ private[parquet] class ParquetFilters(
408430
// filters for the column having dots in the names. Thus, we do not push down such filters.
409431
// See SPARK-20364.
410432
def canMakeFilterOn(name: String, value: Any): Boolean = {
411-
nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
433+
typeMap.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
412434
}
413435

414436
// NOTE:
@@ -428,29 +450,29 @@ private[parquet] class ParquetFilters(
428450

429451
predicate match {
430452
case sources.IsNull(name) if canMakeFilterOn(name, null) =>
431-
makeEq.lift(nameToType(name)).map(_(name, null))
453+
makeEq.lift(typeMap(name)).map(_(nameMap(name), null))
432454
case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
433-
makeNotEq.lift(nameToType(name)).map(_(name, null))
455+
makeNotEq.lift(typeMap(name)).map(_(nameMap(name), null))
434456

435457
case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
436-
makeEq.lift(nameToType(name)).map(_(name, value))
458+
makeEq.lift(typeMap(name)).map(_(nameMap(name), value))
437459
case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
438-
makeNotEq.lift(nameToType(name)).map(_(name, value))
460+
makeNotEq.lift(typeMap(name)).map(_(nameMap(name), value))
439461

440462
case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
441-
makeEq.lift(nameToType(name)).map(_(name, value))
463+
makeEq.lift(typeMap(name)).map(_(nameMap(name), value))
442464
case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
443-
makeNotEq.lift(nameToType(name)).map(_(name, value))
465+
makeNotEq.lift(typeMap(name)).map(_(nameMap(name), value))
444466

445467
case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
446-
makeLt.lift(nameToType(name)).map(_(name, value))
468+
makeLt.lift(typeMap(name)).map(_(nameMap(name), value))
447469
case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
448-
makeLtEq.lift(nameToType(name)).map(_(name, value))
470+
makeLtEq.lift(typeMap(name)).map(_(nameMap(name), value))
449471

450472
case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
451-
makeGt.lift(nameToType(name)).map(_(name, value))
473+
makeGt.lift(typeMap(name)).map(_(nameMap(name), value))
452474
case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
453-
makeGtEq.lift(nameToType(name)).map(_(name, value))
475+
makeGtEq.lift(typeMap(name)).map(_(nameMap(name), value))
454476

455477
case sources.And(lhs, rhs) =>
456478
// At here, it is not safe to just convert one side if we do not understand the
@@ -477,7 +499,7 @@ private[parquet] class ParquetFilters(
477499
case sources.In(name, values) if canMakeFilterOn(name, values.head)
478500
&& values.distinct.length <= pushDownInFilterThreshold =>
479501
values.distinct.flatMap { v =>
480-
makeEq.lift(nameToType(name)).map(_(name, v))
502+
makeEq.lift(typeMap(name)).map(_(nameMap(name), v))
481503
}.reduceLeftOption(FilterApi.or)
482504

483505
case sources.StringStartsWith(name, prefix)

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,80 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
10211021
}
10221022
}
10231023
}
1024+
1025+
test("Case-insensitive field resolution for pushdown when reading parquet") {
1026+
def testCaseInsensitiveResolution(
1027+
schema: StructType,
1028+
expected: FilterPredicate,
1029+
filter: sources.Filter): Unit = {
1030+
val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
1031+
1032+
assertResult(Some(expected)) {
1033+
parquetFilters.createFilter(parquetSchema, filter, caseSensitive = false)
1034+
}
1035+
assertResult(None) {
1036+
parquetFilters.createFilter(parquetSchema, filter, caseSensitive = true)
1037+
}
1038+
}
1039+
1040+
val schema = StructType(Seq(StructField("cint", IntegerType)))
1041+
1042+
testCaseInsensitiveResolution(
1043+
schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT"))
1044+
1045+
testCaseInsensitiveResolution(
1046+
schema,
1047+
FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]),
1048+
sources.IsNotNull("CINT"))
1049+
1050+
testCaseInsensitiveResolution(
1051+
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000))
1052+
1053+
testCaseInsensitiveResolution(
1054+
schema,
1055+
FilterApi.notEq(intColumn("cint"), 1000: Integer),
1056+
sources.Not(sources.EqualTo("CINT", 1000)))
1057+
1058+
testCaseInsensitiveResolution(
1059+
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000))
1060+
1061+
testCaseInsensitiveResolution(
1062+
schema,
1063+
FilterApi.notEq(intColumn("cint"), 1000: Integer),
1064+
sources.Not(sources.EqualNullSafe("CINT", 1000)))
1065+
1066+
testCaseInsensitiveResolution(
1067+
schema,
1068+
FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000))
1069+
1070+
testCaseInsensitiveResolution(
1071+
schema,
1072+
FilterApi.ltEq(intColumn("cint"), 1000: Integer),
1073+
sources.LessThanOrEqual("CINT", 1000))
1074+
1075+
testCaseInsensitiveResolution(
1076+
schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000))
1077+
1078+
testCaseInsensitiveResolution(
1079+
schema,
1080+
FilterApi.gtEq(intColumn("cint"), 1000: Integer),
1081+
sources.GreaterThanOrEqual("CINT", 1000))
1082+
1083+
testCaseInsensitiveResolution(
1084+
schema,
1085+
FilterApi.or(
1086+
FilterApi.eq(intColumn("cint"), 10: Integer),
1087+
FilterApi.eq(intColumn("cint"), 20: Integer)),
1088+
sources.In("CINT", Array(10, 20)))
1089+
1090+
val dupFieldSchema = StructType(
1091+
Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType)))
1092+
val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema)
1093+
assertResult(None) {
1094+
parquetFilters.createFilter(
1095+
dupParquetSchema, sources.EqualTo("CINT", 1000), caseSensitive = false)
1096+
}
1097+
}
10241098
}
10251099

10261100
class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {

0 commit comments

Comments
 (0)