diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a7e4f186000b..12d456a371d0 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -652,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb..0b85b208242c 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. 
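Editor's note: a minimal sketch (not part of the patch) of how the widened `SchemaConverters.toSqlType` signature used at the call site above can be exercised, for example from a spark-shell. The recursive `LongList` schema mirrors the one used by the new tests further down; everything else here is illustrative.

```scala
// Editor's sketch (not part of the patch): exercising the new 4-argument toSqlType
// introduced in this change, with the recursive LongList schema from the test suite.
import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters

val longList = new Schema.Parser().parse(
  """
    |{
    |  "type": "record",
    |  "name": "LongList",
    |  "fields": [
    |    {"name": "value", "type": "long"},
    |    {"name": "next", "type": ["null", "LongList"]}
    |  ]
    |}
  """.stripMargin)

// recursiveFieldMaxDepth defaults to -1, which keeps the old behavior: a recursive
// reference makes toSqlType throw IncompatibleSchemaException. With a depth of 3 the
// recursion is unrolled twice and then truncated, yielding roughly
// struct<value: bigint, next: struct<value: bigint, next: struct<value: bigint>>>.
val schemaType = SchemaConverters.toSqlType(
  longList,
  useStableIdForUnionType = false,
  stableIdPrefixForUnionType = "",
  recursiveFieldMaxDepth = 3)

println(schemaType.dataType.sql)
```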
@@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 877c3f89e88c..ac20614553ca 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -51,14 +51,16 @@ private[sql] class AvroDeserializer( datetimeRebaseSpec: RebaseSpec, filters: StructFilters, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) = { this( rootAvroType, rootCatalystType, @@ -66,7 +68,8 @@ private[sql] class AvroDeserializer( RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, useStableIdForUnionType, - stableIdPrefixForUnionType) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) } private lazy val decimalConversions = new DecimalConversion() @@ -128,7 +131,8 @@ private[sql] class AvroDeserializer( s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" val realDataType = SchemaConverters.toSqlType( - avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType + avroType, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 372f24b54f5c..264c3a1f48ab 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat datetimeRebaseMode, avroFilters, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType) + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 4332904339f1..e0c6ad3ee69d 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -136,6 +137,15 @@ private[sql] class AvroOptions( val 
stableIdPrefixForUnionType: String = parameters
     .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_")
+
+  val recursiveFieldMaxDepth: Int =
+    parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1)
+
+  if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) {
+    throw QueryCompilationErrors.avroOptionsException(
+      RECURSIVE_FIELD_MAX_DEPTH,
+      s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.")
+  }
 }
 
 private[sql] object AvroOptions extends DataSourceOptions {
@@ -170,4 +180,25 @@ private[sql] object AvroOptions extends DataSourceOptions {
   // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields
   // of Avro Union type.
   val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType")
+
+  /**
+   * Adds support for recursive fields. If this option is not specified or is set to 0 or a
+   * negative value, recursive fields are not permitted. Setting it to 1 drops all recursive
+   * fields, 2 allows recursive fields to be recursed once, and 3 allows them to be recursed
+   * twice, and so on, up to 15. Values larger than 15 are not allowed in order to avoid
+   * inadvertently creating very large schemas. If an Avro message has depth beyond this limit,
+   * the Spark struct returned is truncated after the recursion limit.
+   *
+   * Examples: Consider an Avro schema with a recursive field:
+   * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"},
+   * {"name": "Next", "type": ["null", "Node"]}]}
+   * The following lists the parsed schema with different values for this setting.
+   * 1: `struct<Id: int>`
+   * 2: `struct<Id: int, Next: struct<Id: int>>`
+   * 3: `struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>`
+   * and so on.
+   */
+  val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth")
+
+  val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15
 }
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 7cbc30f1fb3d..594ebb4716c4 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging {
       SchemaConverters.toSqlType(
         avroSchema,
         parsedOptions.useStableIdForUnionType,
-        parsedOptions.stableIdPrefixForUnionType).dataType match {
+        parsedOptions.stableIdPrefixForUnionType,
+        parsedOptions.recursiveFieldMaxDepth).dataType match {
         case t: StructType => Some(t)
         case _ => throw new RuntimeException(
           s"""Avro schema cannot be converted to a Spark SQL StructType:
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index b2285aa966dd..1168a887abd8 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT
 import org.apache.avro.Schema.Type._
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH}
+import org.apache.spark.internal.MDC
+import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.Decimal.minBytesForPrecision
@@ -36,7 +40,7 @@
import org.apache.spark.sql.types.Decimal.minBytesForPrecision * versa. */ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { private lazy val nullSchema = Schema.create(Schema.Type.NULL) /** @@ -48,14 +52,27 @@ object SchemaConverters { /** * Converts an Avro schema to a corresponding Spark SQL schema. - * + * + * @param avroSchema The Avro schema to convert. + * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema, + * and the Avro Union type is transformed into a structure where + * the field names remain consistent with their respective types. + * @param stableIdPrefixForUnionType The prefix to use to configure the prefix for fields of + * Avro Union type + * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema. + * -1 means not supported. * @since 4.0.0 */ def toSqlType( avroSchema: Schema, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int = -1): SchemaType = { + val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType, + stableIdPrefixForUnionType, recursiveFieldMaxDepth) + // the top level record should never return null + assert(schema != null) + schema } /** * Converts an Avro schema to a corresponding Spark SQL schema. @@ -63,17 +80,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false, "") + toSqlType(avroSchema, false, "", -1) } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { val avroOptions = AvroOptions(options) - toSqlTypeHelper( + toSqlType( avroSchema, - Set.empty, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) } // The property specifies Catalyst type of the given field @@ -81,9 +98,10 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, - existingRecordNames: Set[String], + existingRecordNames: Map[String, Int], useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -128,62 +146,110 @@ object SchemaConverters { case NULL => SchemaType(NullType, nullable = true) case RECORD => - if (existingRecordNames.contains(avroSchema.getFullName)) { + val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) + if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { throw new IncompatibleSchemaException(s""" - |Found recursive reference in Avro schema, which can not be processed by Spark: - |${avroSchema.toString(true)} + |Found recursive reference in Avro schema, which can not be processed by Spark by + | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. 
""".stripMargin) - } - val newRecordNames = existingRecordNames + avroSchema.getFullName - val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper( - f.schema(), - newRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType) - StructField(f.name, schemaType.dataType, schemaType.nullable) - } + } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { + logInfo( + log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " + + log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}." + ) + null + } else { + val newRecordNames = + existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1)) + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null + } + else { + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) + } case ARRAY => val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - SchemaType( - ArrayType(schemaType.dataType, containsNull = schemaType.nullable), - nullable = false) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + } case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) - SchemaType( - MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), - nullable = false) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." 
+ ) + null + } else { + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + } case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) - if (remainingUnionTypes.size == 1) { - toSqlTypeHelper( - remainingUnionTypes.head, - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + val remainingSchema = + if (remainingUnionTypes.size == 1) { + remainingUnionTypes.head + } else { + Schema.createUnion(remainingUnionTypes.asJava) + } + val schemaType = toSqlTypeHelper( + remainingSchema, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null } else { - toSqlTypeHelper( - Schema.createUnion(remainingUnionTypes.asJava), - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + schemaType.copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -201,29 +267,33 @@ object SchemaConverters { s, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - - val fieldName = if (useStableIdForUnionType) { - // Avro's field name may be case sensitive, so field names for two named type - // could be "a" and "A" and we need to distinguish them. In this case, we throw - // an exception. - // Stable id prefix can be empty so the name of the field can be just the type. - val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" - if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { - throw new IncompatibleSchemaException( - "Cannot generate stable identifier for Avro union type due to name " + - s"conflict of type name ${s.getName}") - } - tempFieldName + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null } else { - s"member$i" - } + val fieldName = if (useStableIdForUnionType) { + // Avro's field name may be case sensitive, so field names for two named type + // could be "a" and "A" and we need to distinguish them. In this case, we throw + // an exception. + // Stable id prefix can be empty so the name of the field can be just the type. 
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { + throw new IncompatibleSchemaException( + "Cannot generate stable identifier for Avro union type due to name " + + s"conflict of type name ${s.getName}") + } + tempFieldName + } else { + s"member$i" + } - // All fields are nullable because only one of them is set at a time - StructField(fieldName, schemaType.dataType, nullable = true) - } + // All fields are nullable because only one of them is set at a time + StructField(fieldName, schemaType.dataType, nullable = true) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c9916072..a13faf3b5156 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d..311eda3a1b6a 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 47faaf7662a5..a7f7abadcf48 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { } } + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + 
StructField("Id", IntegerType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { val schema = new Schema.Parser().parse(avroSchema) val datumWriter = new GenericDatumWriter[GenericRecord](schema) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700..c1ab96a63eb2 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a..3643a95abe19 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 14ed6c43e4c0..be887bd5237b 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2220,7 +2220,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2229,7 +2230,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ 
+ | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2311,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } + + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": "TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", + | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2260,9 +2381,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2271,9 +2401,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2282,7 +2421,70 @@ abstract 
class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2777,7 +2979,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2790,6 +2992,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index 3721f92d9326..c06e1fd46d2d 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -353,6 +353,13 @@ Data source options of Avro can be set via: read 4.0.0 + + recursiveFieldMaxDepth + -1 + If this option is specified to negative or is set to 0, recursive fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. Values larger than 15 are not allowed in order to avoid inadvertently creating very large schemas. If an avro message has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. 
An example of usage can be found in section Handling circular references of Avro fields + read + 4.0.0 + + ## Configuration @@ -628,3 +635,41 @@ You can also specify the whole output Avro schema with the option `avroSchema`, decimal +
+## Handling circular references of Avro fields
+In Avro, a circular reference occurs when the type of a field is defined in one of the parent records. This can cause issues when parsing the data, as it can result in infinite loops or other unexpected behavior.
+To read Avro data with a schema that has a circular reference, users can use the `recursiveFieldMaxDepth` option to specify the maximum number of levels of recursion to allow when parsing the schema. By default, the Spark Avro data source does not permit recursive fields, since `recursiveFieldMaxDepth` defaults to -1. However, you can set this option to a value from 1 to 15 if needed.
+
+Setting `recursiveFieldMaxDepth` to 1 drops all recursive fields, setting it to 2 allows them to be recursed once, and setting it to 3 allows them to be recursed twice. A `recursiveFieldMaxDepth` value greater than 15 is not allowed, as it can lead to performance issues and even stack overflows.
+
+The SQL schema for the Avro message below will vary based on the value of `recursiveFieldMaxDepth`.
+
+<div>
+<!-- This div is only used to make markdown editor/viewer happy and does not display on web -->
+
+```avro
+</div>
+
+{% highlight avro %}
+{
+  "type": "record",
+  "name": "Node",
+  "fields": [
+    {"name": "Id", "type": "int"},
+    {"name": "Next", "type": ["null", "Node"]}
+  ]
+}
+
+// The Avro schema defined above would be converted into a Spark SQL column with the following
+// structure based on the `recursiveFieldMaxDepth` value.
+
+1: struct<Id: int>
+2: struct<Id: int, Next: struct<Id: int>>
+3: struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>
+
+{% endhighlight %}
+<div>
+
+```
+</div>
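Editor's note: a hedged usage sketch (not part of the patch) showing the documented option from the DataFrameReader side, matching the reads exercised in the new AvroSuite integration test. The input path and session name are hypothetical.

```scala
// Editor's sketch (not part of the patch): reading recursive "Node" records with the
// new option. The path below is hypothetical.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("recursive-avro-read").getOrCreate()

// With recursiveFieldMaxDepth = 3, Node parses as
// struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>;
// anything nested deeper in the data is truncated at that depth.
val nodes = spark.read
  .format("avro")
  .option("recursiveFieldMaxDepth", "3")
  .load("/tmp/nodes.avro")  // hypothetical path

nodes.printSchema()
```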
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index fa8ea2f5289f..29802d5277e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -4090,6 +4090,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def avroOptionsException(optionName: String, message: String): Throwable = { + new AnalysisException( + errorClass = "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", + messageParameters = Map("optionName" -> optionName, "message" -> message) + ) + } + def protobufNotLoadedSqlFunctionsUnusable(functionName: String): Throwable = { new AnalysisException( errorClass = "PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE",