@@ -347,6 +347,7 @@ class ParquetFileFormat
val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
val isCaseSensitive = sqlConf.caseSensitiveAnalysis

(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -372,7 +373,7 @@ class ParquetFileFormat
val pushed = if (enableParquetFilterPushDown) {
val parquetSchema = footerFileMetaData.getSchema
val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal,
pushDownStringStartWith, pushDownInFilterThreshold)
pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.util.Locale

import scala.collection.JavaConverters.asScalaBufferConverter

@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
import org.apache.spark.sql.sources
import org.apache.spark.unsafe.types.UTF8String
@@ -44,7 +45,18 @@ private[parquet] class ParquetFilters(
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int) {
pushDownInFilterThreshold: Int,
caseSensitive: Boolean) {

/**
* Holds information for a single field stored in the underlying parquet file.
*
* @param fieldName field name in parquet file
* @param fieldType field type related info in parquet file
*/
private case class ParquetField(
fieldName: String,
fieldType: ParquetSchemaType)

private case class ParquetSchemaType(
originalType: OriginalType,
@@ -350,25 +362,38 @@ private[parquet] class ParquetFilters(
}

/**
* Returns a map from name of the column to the data type, if predicate push down applies.
* Returns a map from column names to parquet fields (field name and data type), if predicate push down applies.
*/
private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match {
case m: MessageType =>
// Here we don't flatten the fields in the nested schema but just look up through
// root fields. Currently, accessing to nested fields does not push down filters
// and it does not support to create filters for them.
m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
f.getName -> ParquetSchemaType(
f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
}.toMap
case _ => Map.empty[String, ParquetSchemaType]
private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
// Here we don't flatten the fields in the nested schema but just look up through
// root fields. Currently, accessing to nested fields does not push down filters
// and it does not support to create filters for them.
val primitiveFields =
dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
f.getName -> ParquetField(f.getName,
ParquetSchemaType(f.getOriginalType,
f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
}
if (caseSensitive) {
primitiveFields.toMap
} else {
// Don't consider ambiguity here, i.e. more than one field matched in case-insensitive
// mode; just skip pushdown for these fields. They will trigger an exception when read.
// See SPARK-25132.

Member:
If we don't need to consider ambiguity, can't we just lowercase f.getName above instead of doing dedup here?

Contributor Author:
It is a good question!

Let's look at a scenario like the one below:

  1. The parquet file has duplicate fields "a INT, A INT".
  2. The user wants to push down "A > 0".

Without dedup, we might push down "a > 0" instead of "A > 0".
Although that is wrong, it will still trigger the exception when the parquet file is read,
so with or without dedup we get the same result.

@cloud-fan , @gatorsmile any idea?
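
For concreteness, here is a minimal, standalone sketch of the dedup being discussed, on a schema with fields "a", "A" and "b" (the names and placeholder types are illustrative only, not taken from the PR):

    import java.util.Locale

    // Field names as they appear in the parquet file, with placeholder types.
    val primitiveFields = Seq("a" -> "INT32", "A" -> "INT32", "b" -> "INT32")

    // Keep only names that are unambiguous under case-insensitive comparison.
    val dedupPrimitiveFields = primitiveFields
      .groupBy(_._1.toLowerCase(Locale.ROOT))   // "a" -> Seq(a, A), "b" -> Seq(b)
      .filter(_._2.size == 1)                   // drop the ambiguous "a"/"A" group
      .map { case (_, group) => group.head }    // keep the unique field as-is

    // Filters on "a"/"A" are therefore never pushed down; the duplicate columns
    // surface as an error only when the file is actually read (SPARK-25132).
    assert(dedupPrimitiveFields == Map("b" -> "INT32"))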

Contributor:
can we do the dedup before parquet filter pushdown and parquet column pruning? Then we can simplify the code in both cases.

Contributor:
ping @yucai

Contributor Author:
@cloud-fan, it is a great idea, thanks!
I think it is not quite a "dedup" before pushdown and pruning.
Maybe we should clip the parquet schema before pushdown and pruning.
If duplicate fields are detected, throw an exception.
If not, pass the clipped parquet schema to the parquet lib via the Hadoop conf.

    catalystRequestedSchema = {
      val conf = context.getConfiguration
      val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
      assert(schemaString != null, "Parquet requested schema not set.")
      StructType.fromString(schemaString)
    }

    val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key,
      SQLConf.CASE_SENSITIVE.defaultValue.get)
    val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema(
      context.getFileSchema, catalystRequestedSchema, caseSensitive)

I am trying this way, will update soon.

val dedupPrimitiveFields =
primitiveFields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.mapValues(_.head._2)
CaseInsensitiveMap(dedupPrimitiveFields)
}
}
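
As an aside, the deduped result is wrapped in CaseInsensitiveMap because the pushed filter may reference a column with different casing than the file schema. A rough illustration of that lookup behaviour, using a hypothetical InsensitiveLookup helper (not Spark's class):

    import java.util.Locale

    // Hypothetical helper mimicking a case-insensitive map: keys are normalized
    // to lower case, so lookups ignore the caller's casing.
    final case class InsensitiveLookup[V](underlying: Map[String, V]) {
      private val byLower: Map[String, V] =
        underlying.map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v }
      def get(key: String): Option[V] = byLower.get(key.toLowerCase(Locale.ROOT))
      def contains(key: String): Boolean = get(key).isDefined
    }

    // The parquet file stores "cint", but the filter may say "CINT" or "cInt".
    val fields = InsensitiveLookup(Map("cint" -> "INT32"))
    assert(fields.contains("CINT"))
    assert(fields.get("cInt").contains("INT32"))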

/**
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
val nameToType = getFieldMap(schema)
val nameToParquetField = getFieldMap(schema)

// Decimal type must make sure that filter value's scale matched the file.
// If doesn't matched, which would cause data corruption.
@@ -381,7 +406,7 @@ private[parquet] class ParquetFilters(
// Parquet's type in the given file should be matched to the value's type
// in the pushed filter in order to push down the filter to Parquet.
def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
value == null || (nameToType(name) match {
value == null || (nameToParquetField(name).fieldType match {
case ParquetBooleanType => value.isInstanceOf[JBoolean]
case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
case ParquetLongType => value.isInstanceOf[JLong]
@@ -408,7 +433,7 @@ private[parquet] class ParquetFilters(
// filters for the column having dots in the names. Thus, we do not push down such filters.
// See SPARK-20364.
def canMakeFilterOn(name: String, value: Any): Boolean = {
nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
}

// NOTE:
@@ -428,29 +453,39 @@

predicate match {
case sources.IsNull(name) if canMakeFilterOn(name, null) =>
makeEq.lift(nameToType(name)).map(_(name, null))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, null))
case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
makeNotEq.lift(nameToType(name)).map(_(name, null))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, null))

case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToType(name)).map(_(name, value))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToType(name)).map(_(name, value))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToType(name)).map(_(name, value))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToType(name)).map(_(name, value))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
makeLt.lift(nameToType(name)).map(_(name, value))
makeLt.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeLtEq.lift(nameToType(name)).map(_(name, value))
makeLtEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
makeGt.lift(nameToType(name)).map(_(name, value))
makeGt.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeGtEq.lift(nameToType(name)).map(_(name, value))
makeGtEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.And(lhs, rhs) =>
// At here, it is not safe to just convert one side if we do not understand the
@@ -477,7 +512,8 @@ private[parquet] class ParquetFilters(
case sources.In(name, values) if canMakeFilterOn(name, values.head)
&& values.distinct.length <= pushDownInFilterThreshold =>
values.distinct.flatMap { v =>
makeEq.lift(nameToType(name)).map(_(name, v))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, v))
}.reduceLeftOption(FilterApi.or)

case sources.StringStartsWith(name, prefix)
@@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operato
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}

import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
@@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
private lazy val parquetFilters =
new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
conf.parquetFilterPushDownInFilterThreshold)
conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis)

override def beforeEach(): Unit = {
super.beforeEach()
@@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
}

test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") {
def createParquetFilter(caseSensitive: Boolean): ParquetFilters = {
new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
conf.parquetFilterPushDownInFilterThreshold, caseSensitive)
}
val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true)
val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false)

def testCaseInsensitiveResolution(
schema: StructType,
expected: FilterPredicate,
filter: sources.Filter): Unit = {
val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)

assertResult(Some(expected)) {
caseInsensitiveParquetFilters.createFilter(parquetSchema, filter)
}
assertResult(None) {
caseSensitiveParquetFilters.createFilter(parquetSchema, filter)
}
}

val schema = StructType(Seq(StructField("cint", IntegerType)))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT"))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]),
sources.IsNotNull("CINT"))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), 1000: Integer),
sources.Not(sources.EqualTo("CINT", 1000)))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), 1000: Integer),
sources.Not(sources.EqualNullSafe("CINT", 1000)))

testCaseInsensitiveResolution(
schema,
FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.ltEq(intColumn("cint"), 1000: Integer),
sources.LessThanOrEqual("CINT", 1000))

testCaseInsensitiveResolution(
schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.gtEq(intColumn("cint"), 1000: Integer),
sources.GreaterThanOrEqual("CINT", 1000))

Member:
nit: maybe we don't need to test against so many predicates. We just want to make sure case-insensitive resolution works.

Contributor Author:
Each test corresponds to one line of the code change in createFilter, like:

      case sources.IsNull(name) if canMakeFilterOn(name, null) =>
        makeEq.lift(fieldMap(name).schema).map(_(fieldMap(name).name, null))

All the tests together cover all my changes in createFilter.


testCaseInsensitiveResolution(
schema,
FilterApi.or(
FilterApi.eq(intColumn("cint"), 10: Integer),
FilterApi.eq(intColumn("cint"), 20: Integer)),
sources.In("CINT", Array(10, 20)))

val dupFieldSchema = StructType(
Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType)))
val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema)
assertResult(None) {
caseInsensitiveParquetFilters.createFilter(
dupParquetSchema, sources.EqualTo("CINT", 1000))
}

Member:
Can we add one negative test that has duplicate names in case-insensitive mode, for example cInt and CINT, and check that it throws an exception?

Contributor Author:
Added, thanks!

}

test("SPARK-25207: exception when duplicate fields in case-insensitive mode") {
withTempPath { dir =>
val count = 10
val tableName = "spark_25207"
val tableDir = dir.getAbsoluteFile + "/table"
withTable(tableName) {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
spark.range(count).selectExpr("id as A", "id as B", "id as b")
.write.mode("overwrite").parquet(tableDir)
}
sql(
s"""
|CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir'
""".stripMargin)

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val e = intercept[SparkException] {
sql(s"select a from $tableName where b > 0").collect()

Contributor:
can we read this table with case-sensitive mode?

Contributor Author:
Yes, we can, see below.

val tableName = "test"
val tableDir = "/tmp/data"
spark.conf.set("spark.sql.caseSensitive", true)
spark.range(10).selectExpr("id as A", "2 * id as B", "3 * id as b").write.mode("overwrite").parquet(tableDir)
sql(s"DROP TABLE $tableName")
sql(s"CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir'")
scala> sql("select A from test where B > 0").show
+---+
|  A|
+---+
|  7|
|  8|
|  9|
|  2|
|  3|
|  4|
|  5|
|  6|
|  1|
+---+

Let me add one test case.

Member:
nit: to be consistent with the following query, I'd make this query select A from $tableName where B > 0 too.

}
assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains(
"""Found duplicate field(s) "B": [B, b] in case-insensitive mode"""))
}

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_)))
}
}
}
}
}

class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {