From 28315888eaae5a9c9160ea53eb6eb9a9af712958 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 21 Aug 2018 10:34:23 +0800 Subject: [PATCH] [SPARK-25132][SQL][BACKPORT-2.3] Case-insensitive field resolution when reading from Parquet --- .../parquet/ParquetFileFormat.scala | 3 + .../parquet/ParquetReadSupport.scala | 84 +++++++++++++------ .../spark/sql/FileBasedDataSourceSuite.scala | 43 ++++++++++ .../parquet/ParquetSchemaSuite.scala | 61 ++++++++++++-- 4 files changed, 161 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b0ba21e47df45..ddd94e119246a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -310,6 +310,9 @@ class ParquetFileFormat hadoopConf.set( SQLConf.SESSION_LOCAL_TIMEZONE.key, sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 40ce5d5e0564e..3319e73f2b313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Map => JMap, TimeZone} +import java.util.{Locale, Map => JMap, TimeZone} import scala.collection.JavaConverters._ @@ -30,6 +30,7 @@ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -71,8 +72,10 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) StructType.fromString(schemaString) } - val parquetRequestedSchema = - ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key, + SQLConf.CASE_SENSITIVE.defaultValue.get) + val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema( + context.getFileSchema, catalystRequestedSchema, caseSensitive) new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -117,8 +120,12 @@ private[parquet] object ParquetReadSupport { * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist * in `catalystSchema`, and adding those only exist in `catalystSchema`. */ - def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { - val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = clipParquetGroupFields( + parquetSchema.asGroupType(), catalystSchema, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -129,20 +136,21 @@ private[parquet] object ParquetReadSupport { } } - private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + private def clipParquetType( + parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -168,14 +176,15 @@ private[parquet] object ParquetReadSupport { * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a * [[StructType]]. */ - private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + private def clipParquetListType( + parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) // Unannotated repeated group should be interpreted as required list of required element, so // list element type is just the group itself. Clip it. if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType) + clipParquetType(parquetList, elementType, caseSensitive) } else { assert( parquetList.getOriginalType == OriginalType.LIST, @@ -207,7 +216,7 @@ private[parquet] object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(OriginalType.LIST) - .addField(clipParquetType(repeatedGroup, elementType)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) .named(parquetList.getName) } else { // Otherwise, the repeated field's type is the element type with the repeated field's @@ -218,7 +227,7 @@ private[parquet] object ParquetReadSupport { .addField( Types .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) .named(repeatedGroup.getName)) .named(parquetList.getName) } @@ -231,7 +240,10 @@ private[parquet] object ParquetReadSupport { * a [[StructType]]. */ private def clipParquetMapType( - parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -243,8 +255,8 @@ private[parquet] object ParquetReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(clipParquetType(parquetKeyType, keyType)) - .addField(clipParquetType(parquetValueType, valueType)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) .named(repeatedGroup.getName) Types @@ -262,8 +274,9 @@ private[parquet] object ParquetReadSupport { * [[MessageType]]. Because it's legal to construct an empty requested schema for column * pruning. */ - private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + private def clipParquetGroup( + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getOriginalType) @@ -277,14 +290,35 @@ private[parquet] object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - structType.map { f => - parquetFieldMap - .get(f.name) - .map(clipParquetType(_, f.dataType)) - .getOrElse(toParquet.convertField(f)) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + }.getOrElse(toParquet.convertField(f)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b5d4c558f0d3e..cb96407096016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -156,4 +156,47 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + test(s"SPARK-25132: case-insensitive field resolution when reading from Parquet") { + withTempDir { dir => + val format = "parquet" + val tableDir = dir.getCanonicalPath + s"/$format" + val tableName = s"spark_25132_${format}" + withTable(tableName) { + val end = 5 + val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + data.write.format(format).mode("overwrite").save(tableDir) + } + sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer(sql(s"select a from $tableName"), data.select("A")) + checkAnswer(sql(s"select A from $tableName"), data.select("A")) + + // RuntimeException is triggered at executor side, which is then wrapped as + // SparkException at driver side + val e1 = intercept[SparkException] { + sql(s"select b from $tableName").collect() + } + assert( + e1.getCause.isInstanceOf[RuntimeException] && + e1.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + val e2 = intercept[SparkException] { + sql(s"select B from $tableName").collect() + } + assert( + e2.getCause.isInstanceOf[RuntimeException] && + e2.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) + checkAnswer(sql(s"select b from $tableName"), data.select("b")) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9d3dfae348beb..4805b1e76806a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -1014,19 +1014,21 @@ class ParquetSchemaSuite extends ParquetSchemaTest { testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: String): Unit = { + expectedSchema: String, + caseSensitive: Boolean = true): Unit = { testSchemaClipping(testName, parquetSchema, catalystSchema, - MessageTypeParser.parseMessageType(expectedSchema)) + MessageTypeParser.parseMessageType(expectedSchema), caseSensitive) } private def testSchemaClipping( testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: MessageType): Unit = { + expectedSchema: MessageType, + caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) try { expectedSchema.checkContains(actual) @@ -1387,7 +1389,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), - expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE) + expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE, + caseSensitive = true) testSchemaClipping( "disjoint field sets", @@ -1544,4 +1547,52 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin) + + testSchemaClipping( + "case-insensitive resolution: no ambiguity", + parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + }, + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + caseSensitive = false) + + test("Clipping - case-insensitive resolution: more than one field is matched") { + val parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + | optional int32 a; + |} + """.stripMargin + val catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + } + assertThrows[RuntimeException] { + ParquetReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + } + } }