From ac017a8e22b2d20d5c6f4b6e280c7366231220d5 Mon Sep 17 00:00:00 2001 From: Ivan Sadikov Date: Wed, 31 Jan 2024 15:45:00 +1300 Subject: [PATCH] [SPARK-46930] Add stableIdentifierPrefixForUnionType option for Avro union type fields --- .../spark/sql/avro/AvroDataToCatalyst.scala | 12 +++- .../spark/sql/avro/AvroDeserializer.scala | 12 ++-- .../spark/sql/avro/AvroFileFormat.scala | 3 +- .../apache/spark/sql/avro/AvroOptions.scala | 6 ++ .../org/apache/spark/sql/avro/AvroUtils.scala | 5 +- .../spark/sql/avro/SchemaConverters.scala | 58 +++++++++++++------ .../v2/avro/AvroPartitionReaderFactory.scala | 3 +- .../AvroCatalystDataConversionSuite.scala | 7 ++- .../spark/sql/avro/AvroFunctionsSuite.scala | 3 +- .../spark/sql/avro/AvroRowReaderSuite.scala | 3 +- .../spark/sql/avro/AvroSerdeSuite.scala | 3 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 54 +++++++++++++---- docs/sql-data-sources-avro.md | 10 +++- 13 files changed, 133 insertions(+), 46 deletions(-) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 9f31a2db55a52..7d80998d96eb1 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -40,7 +40,9 @@ private[sql] case class AvroDataToCatalyst( override lazy val dataType: DataType = { val dt = SchemaConverters.toSqlType( - expectedSchema, avroOptions.useStableIdForUnionType).dataType + expectedSchema, + avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema.
@@ -62,8 +64,12 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val reader = new GenericDatumReader[Any](actualSchema, expectedSchema) @transient private lazy val deserializer = - new AvroDeserializer(expectedSchema, dataType, - avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType) + new AvroDeserializer( + expectedSchema, + dataType, + avroOptions.datetimeRebaseModeInRead, + avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 9e10fac8bb552..139c45adb4421 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -50,20 +50,23 @@ private[sql] class AvroDeserializer( positionalFieldMatch: Boolean, datetimeRebaseSpec: RebaseSpec, filters: StructFilters, - useStableIdForUnionType: Boolean) { + useStableIdForUnionType: Boolean, + stableIdPrefixForUnionType: String) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, - useStableIdForUnionType: Boolean) = { + useStableIdForUnionType: Boolean, + stableIdPrefixForUnionType: String) = { this( rootAvroType, rootCatalystType, positionalFieldMatch = false, RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, - useStableIdForUnionType) + useStableIdForUnionType, + stableIdPrefixForUnionType) } private lazy val decimalConversions = new DecimalConversion() @@ -124,7 +127,8 @@ private[sql] class AvroDeserializer( val incompatibleMsg = errorPrefix + s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" - val realDataType = SchemaConverters.toSqlType(avroType, useStableIdForUnionType).dataType + val realDataType = SchemaConverters.toSqlType( + avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 7b0292df43c2f..2792edaea2843 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -142,7 +142,8 @@ private[sql] class AvroFileFormat extends FileFormat parsedOptions.positionalFieldMatching, datetimeRebaseMode, avroFilters, - parsedOptions.useStableIdForUnionType) + parsedOptions.useStableIdForUnionType, + parsedOptions.stableIdPrefixForUnionType) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index a0db82f987162..4332904339f19 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -133,6 +133,9 @@ private[sql] class AvroOptions( val useStableIdForUnionType: Boolean = parameters.get(STABLE_ID_FOR_UNION_TYPE).map(_.toBoolean).getOrElse(false) + + val stableIdPrefixForUnionType: String = parameters + .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, 
"member_") } private[sql] object AvroOptions extends DataSourceOptions { @@ -164,4 +167,7 @@ private[sql] object AvroOptions extends DataSourceOptions { // type name are identical regardless of case, an exception will be raised. However, in other // cases, the field names can be uniquely identified. val STABLE_ID_FOR_UNION_TYPE = newOption("enableStableIdentifiersForUnionType") + // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields + // of Avro Union type. + val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType") } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 05562c913b19e..c1365d1b5ae1c 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -61,7 +61,10 @@ private[sql] object AvroUtils extends Logging { new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles) } - SchemaConverters.toSqlType(avroSchema, parsedOptions.useStableIdForUnionType).dataType match { + SchemaConverters.toSqlType( + avroSchema, + parsedOptions.useStableIdForUnionType, + parsedOptions.stableIdPrefixForUnionType).dataType match { case t: StructType => Some(t) case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 00fb32794e3ad..387526d40f68f 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -51,8 +51,11 @@ object SchemaConverters { * * @since 4.0.0 */ - def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType) + def toSqlType( + avroSchema: Schema, + useStableIdForUnionType: Boolean, + stableIdPrefixForUnionType: String): SchemaType = { + toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) } /** * Converts an Avro schema to a corresponding Spark SQL schema. 
@@ -60,12 +63,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false) + toSqlType(avroSchema, false, "") } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options).useStableIdForUnionType) + val avroOptions = AvroOptions(options) + toSqlTypeHelper( + avroSchema, + Set.empty, + avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType) } // The property specifies Catalyst type of the given field @@ -74,7 +82,8 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, existingRecordNames: Set[String], - useStableIdForUnionType: Boolean): SchemaType = { + useStableIdForUnionType: Boolean, + stableIdPrefixForUnionType: String): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -127,7 +136,11 @@ object SchemaConverters { } val newRecordNames = existingRecordNames + avroSchema.getFullName val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType) + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType) StructField(f.name, schemaType.dataType, schemaType.nullable) } @@ -137,14 +150,15 @@ object SchemaConverters { val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, - useStableIdForUnionType) + useStableIdForUnionType, + stableIdPrefixForUnionType) SchemaType( ArrayType(schemaType.dataType, containsNull = schemaType.nullable), nullable = false) case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) SchemaType( MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), nullable = false) @@ -154,18 +168,22 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) if (remainingUnionTypes.size == 1) { - toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, useStableIdForUnionType) - .copy(nullable = true) + toSqlTypeHelper( + remainingUnionTypes.head, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType).copy(nullable = true) } else { toSqlTypeHelper( Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames, - useStableIdForUnionType).copy(nullable = true) + useStableIdForUnionType, + stableIdPrefixForUnionType).copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -179,20 +197,26 @@ object SchemaConverters { val fieldNameSet : mutable.Set[String] = mutable.Set() val fields = avroSchema.getTypes.asScala.zipWithIndex.map { case (s, i) => - val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType) + val schemaType = toSqlTypeHelper( + s, + 
existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType) val fieldName = if (useStableIdForUnionType) { // Avro's field name may be case sensitive, so field names for two named types // could be "a" and "A" and we need to distinguish them. In this case, we throw // an exception. + // The stable id prefix can be empty, so the field name may be just the type name. + val tempFieldName = + s"${stableIdPrefixForUnionType}${s.getName.toLowerCase(Locale.ROOT)}" + if (fieldNameSet.contains(tempFieldName)) { throw new IncompatibleSchemaException( "Cannot generate stable identifier for Avro union type due to name " + s"conflict of type name ${s.getName}") } - fieldNameSet.add(temp_name) - temp_name + fieldNameSet.add(tempFieldName) + tempFieldName } else { s"member$i" } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 2c85c1b067392..1083c99160724 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -104,7 +104,8 @@ case class AvroPartitionReaderFactory( options.positionalFieldMatching, datetimeRebaseMode, avroFilters, - options.useStableIdForUnionType) + options.useStableIdForUnionType, + options.stableIdPrefixForUnionType) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 633bbce8df801..388347537a4d6 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -60,7 +60,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite val expected = { val avroSchema = new Schema.Parser().parse(schema) - SchemaConverters.toSqlType(avroSchema, false).dataType match { + SchemaConverters.toSqlType(avroSchema, false, "").dataType match { case st: StructType => Row.fromSeq((0 until st.length).map(_ => null)) case _ => null } @@ -283,14 +283,15 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite data: GenericData.Record, expected: Option[Any], filters: StructFilters = new NoopFilters): Unit = { - val dataType = SchemaConverters.toSqlType(schema, false).dataType + val dataType = SchemaConverters.toSqlType(schema, false, "").dataType val deserializer = new AvroDeserializer( schema, dataType, false, RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, - false) + false, + "") val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 9095f1c0831a3..d16ddb4973205 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -264,7 +264,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
val avroOptions = AvroOptions(options) val avroSchema = avroOptions.schema.get val sparkSchema = SchemaConverters - .toSqlType(avroSchema, avroOptions.useStableIdForUnionType) + .toSqlType(avroSchema, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType) .dataType .asInstanceOf[StructType] diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 7117ef4b21e83..9b3bb929a700d 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -76,7 +76,8 @@ class AvroRowReaderSuite false, RebaseSpec(CORRECTED), new NoopFilters, - false) + false, + "") override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index c6d0398017ef3..cbcbc2e7e76a6 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -227,7 +227,8 @@ object AvroSerdeSuite { isPositional(matchType), RebaseSpec(CORRECTED), new NoopFilters, - false) + false, + "") } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 3d481d1d731db..61d93ef82336f 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -103,14 +103,16 @@ abstract class AvroSuite // Check whether an Avro schema of union type is converted to SQL in an expected way, when the // stable ID option is on. // - // @param types Avro types that contain in an Avro union type - // @param expectedSchema expeted SQL schema, provided in DDL string form - // @param fieldsAndRow A list of rows to be appended to the Avro file and the expected - // converted SQL rows + // @param types Avro types contained in an Avro union type + // @param expectedSchema Expected SQL schema, provided in DDL string form + // @param fieldsAndRow A list of rows to be appended to the Avro file and the expected + // converted SQL rows + // @param stableIdPrefixOpt Stable id prefix to use for the union type private def checkUnionStableId( types: List[Schema], expectedSchema: String, - fieldsAndRow: Seq[(Any, Row)]): Unit = { + fieldsAndRow: Seq[(Any, Row)], + stableIdPrefixOpt: Option[String] = None): Unit = { withTempDir { dir => val unionType = Schema.createUnion( types.asJava @@ -137,11 +139,16 @@ abstract class AvroSuite dataFileWriter.flush() dataFileWriter.close() - val df = spark - .read. - format("avro") + var dfReader = spark + .read + .format("avro") .option(AvroOptions.STABLE_ID_FOR_UNION_TYPE, "true") - .load(s"$dir.avro") + + stableIdPrefixOpt.foreach { prefix => + dfReader = dfReader.option(AvroOptions.STABLE_ID_PREFIX_FOR_UNION_TYPE, prefix) + } + + val df = dfReader.load(s"$dir.avro") assert(df.schema === StructType.fromDDL("field1 " + expectedSchema)) assert(df.collect().toSet == fieldsAndRow.map(fr => Row(fr._2)).toSet) } } @@ -320,7 +327,7 @@ abstract class AvroSuite } } - // The test test Avro option "enableStableIdentifiersForUnionType".
It adds all types into + // The test verifies Avro option "enableStableIdentifiersForUnionType". It adds all types into // union and validates they are converted to expected SQL field names. The test also creates // different cases that might cause field name conflicts and checks that they are handled properly. test("SPARK-43333: Stable field names when converting Union type") { @@ -435,6 +442,28 @@ abstract class AvroSuite } } + test("SPARK-46930: Use custom prefix for stable ids when converting Union type") { + // Test default "member_" prefix. + checkUnionStableId( + List(Type.INT, Type.NULL, Type.STRING).map(Schema.create(_)), + "struct<member_int: int, member_string: string>", + Seq( + (42, Row(42, null)), + ("Alice", Row(null, "Alice")))) + + // Test user-configured prefixes. + for (prefix <- Seq("tmp_", "tmp", "member", "MEMBER_", "__", "")) { + checkUnionStableId( + List(Type.INT, Type.NULL, Type.STRING).map(Schema.create(_)), + s"struct<${prefix}int: int, ${prefix}string: string>", + Seq( + (42, Row(42, null)), + ("Alice", Row(null, "Alice"))), + Some(prefix) + ) + } + } + test("SPARK-27858 Union type: More than one non-null type") { Seq(true, false).foreach { isStableUnionMember => withTempDir { dir => @@ -2146,7 +2175,7 @@ abstract class AvroSuite private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { val message = intercept[IncompatibleSchemaException] { - SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false) + SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage assert(message.contains("Found recursive reference in Avro schema")) @@ -2703,7 +2732,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 10) + assert(AvroOptions.getAllOptions.size == 11) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2715,6 +2744,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("positionalFieldMatching")) assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) + assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) } test("SPARK-46633: read file with empty blocks") { diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index cbc3367e5f852..712d4d3b8cd46 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -327,7 +327,15 @@ Data source options of Avro can be set via: If it is set to true, Avro schema is deserialized into Spark SQL schema, and the Avro Union type is transformed into a structure where the field names remain consistent with their respective types. The resulting field names are converted to lowercase, e.g. member_int or member_string. If two user-defined type names or a user-defined type name and a built-in type name are identical regardless of case, an exception will be raised. However, in other cases, the field names can be uniquely identified.</td> <td>read</td> <td>3.5.0</td> </tr> + <tr> + <td><code>stableIdentifierPrefixForUnionType</code></td> + <td>member_</td> + <td>When `enableStableIdentifiersForUnionType` is enabled, this option allows configuring the prefix for fields of the Avro union type.</td> + <td>read</td> + <td>4.0.0</td> + </tr> </table> ## Configuration Configuration of Avro can be done via `spark.conf.set` or by running `SET key=value` commands using SQL.
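For readers trying the feature out, here is a minimal sketch of how the two options combine from the user's side once this patch is applied. The option keys match the patch; the object name, session setup, and input path are illustrative only, assuming an Avro file whose single field is the union ["int", "null", "string"] as in the SPARK-46930 test above.

```scala
import org.apache.spark.sql.SparkSession

object StableIdPrefixExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("stable-id-prefix-example")
      .master("local[*]") // illustrative local setup
      .getOrCreate()

    val df = spark.read
      .format("avro")
      .option("enableStableIdentifiersForUnionType", "true")
      // The new option added by this patch; defaults to "member_" when unset.
      .option("stableIdentifierPrefixForUnionType", "tmp_")
      .load("/tmp/union_example.avro") // hypothetical input file

    // The union column is decoded as struct<tmp_int: int, tmp_string: string>
    // rather than the default struct<member_int: int, member_string: string>.
    df.printSchema()

    spark.stop()
  }
}
```

An empty prefix ("") is also accepted, as exercised by the prefix list in the SPARK-46930 test: the struct fields are then named after the bare Avro type, e.g. `int` and `string`.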