Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ private[sql] case class AvroDataToCatalyst(

override lazy val dataType: DataType = {
val dt = SchemaConverters.toSqlType(
expectedSchema, avroOptions.useStableIdForUnionType).dataType
expectedSchema,
avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType).dataType
parseMode match {
// With PermissiveMode, the output Catalyst row might contain columns of null values for
// corrupt records, even if some of the columns are not nullable in the user-provided schema.
Expand All @@ -62,8 +64,12 @@ private[sql] case class AvroDataToCatalyst(
@transient private lazy val reader = new GenericDatumReader[Any](actualSchema, expectedSchema)

@transient private lazy val deserializer =
new AvroDeserializer(expectedSchema, dataType,
avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType)
new AvroDeserializer(
expectedSchema,
dataType,
avroOptions.datetimeRebaseModeInRead,
avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType)

@transient private var decoder: BinaryDecoder = _

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ private[sql] class AvroDeserializer(
positionalFieldMatch: Boolean,
datetimeRebaseSpec: RebaseSpec,
filters: StructFilters,
useStableIdForUnionType: Boolean) {
useStableIdForUnionType: Boolean,
stableIdPrefixForUnionType: String) {

def this(
rootAvroType: Schema,
rootCatalystType: DataType,
datetimeRebaseMode: String,
useStableIdForUnionType: Boolean) = {
useStableIdForUnionType: Boolean,
stableIdPrefixForUnionType: String) = {
this(
rootAvroType,
rootCatalystType,
positionalFieldMatch = false,
RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
new NoopFilters,
useStableIdForUnionType)
useStableIdForUnionType,
stableIdPrefixForUnionType)
}

private lazy val decimalConversions = new DecimalConversion()
Expand Down Expand Up @@ -124,7 +127,8 @@ private[sql] class AvroDeserializer(
val incompatibleMsg = errorPrefix +
s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})"

val realDataType = SchemaConverters.toSqlType(avroType, useStableIdForUnionType).dataType
val realDataType = SchemaConverters.toSqlType(
avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType

(avroType.getType, catalystType) match {
case (NULL, NullType) => (updater, ordinal, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ private[sql] class AvroFileFormat extends FileFormat
parsedOptions.positionalFieldMatching,
datetimeRebaseMode,
avroFilters,
parsedOptions.useStableIdForUnionType)
parsedOptions.useStableIdForUnionType,
parsedOptions.stableIdPrefixForUnionType)
override val stopPosition = file.start + file.length

override def hasNext: Boolean = hasNextRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ private[sql] class AvroOptions(

val useStableIdForUnionType: Boolean =
parameters.get(STABLE_ID_FOR_UNION_TYPE).map(_.toBoolean).getOrElse(false)

val stableIdPrefixForUnionType: String = parameters
.getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_")
}

private[sql] object AvroOptions extends DataSourceOptions {
Expand Down Expand Up @@ -164,4 +167,7 @@ private[sql] object AvroOptions extends DataSourceOptions {
// type name are identical regardless of case, an exception will be raised. However, in other
// cases, the field names can be uniquely identified.
val STABLE_ID_FOR_UNION_TYPE = newOption("enableStableIdentifiersForUnionType")
// When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields
// of Avro Union type.
val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType")
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ private[sql] object AvroUtils extends Logging {
new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles)
}

SchemaConverters.toSqlType(avroSchema, parsedOptions.useStableIdForUnionType).dataType match {
SchemaConverters.toSqlType(
avroSchema,
parsedOptions.useStableIdForUnionType,
parsedOptions.stableIdPrefixForUnionType).dataType match {
case t: StructType => Some(t)
case _ => throw new RuntimeException(
s"""Avro schema cannot be converted to a Spark SQL StructType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,29 @@ object SchemaConverters {
*
* @since 4.0.0
*/
def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType)
def toSqlType(
avroSchema: Schema,
useStableIdForUnionType: Boolean,
stableIdPrefixForUnionType: String): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType)
}
/**
* Converts an Avro schema to a corresponding Spark SQL schema.
*
* @since 2.4.0
*/
def toSqlType(avroSchema: Schema): SchemaType = {
toSqlType(avroSchema, false)
toSqlType(avroSchema, false, "")
}

@deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0")
def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options).useStableIdForUnionType)
val avroOptions = AvroOptions(options)
toSqlTypeHelper(
avroSchema,
Set.empty,
avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType)
}

// The property specifies Catalyst type of the given field
Expand All @@ -74,7 +82,8 @@ object SchemaConverters {
private def toSqlTypeHelper(
avroSchema: Schema,
existingRecordNames: Set[String],
useStableIdForUnionType: Boolean): SchemaType = {
useStableIdForUnionType: Boolean,
stableIdPrefixForUnionType: String): SchemaType = {
avroSchema.getType match {
case INT => avroSchema.getLogicalType match {
case _: Date => SchemaType(DateType, nullable = false)
Expand Down Expand Up @@ -127,7 +136,11 @@ object SchemaConverters {
}
val newRecordNames = existingRecordNames + avroSchema.getFullName
val fields = avroSchema.getFields.asScala.map { f =>
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType)
val schemaType = toSqlTypeHelper(
f.schema(),
newRecordNames,
useStableIdForUnionType,
stableIdPrefixForUnionType)
StructField(f.name, schemaType.dataType, schemaType.nullable)
}

Expand All @@ -137,14 +150,15 @@ object SchemaConverters {
val schemaType = toSqlTypeHelper(
avroSchema.getElementType,
existingRecordNames,
useStableIdForUnionType)
useStableIdForUnionType,
stableIdPrefixForUnionType)
SchemaType(
ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
nullable = false)

case MAP =>
val schemaType = toSqlTypeHelper(avroSchema.getValueType,
existingRecordNames, useStableIdForUnionType)
existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType)
SchemaType(
MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
nullable = false)
Expand All @@ -154,18 +168,22 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
if (remainingUnionTypes.size == 1) {
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, useStableIdForUnionType)
.copy(nullable = true)
toSqlTypeHelper(
remainingUnionTypes.head,
existingRecordNames,
useStableIdForUnionType,
stableIdPrefixForUnionType).copy(nullable = true)
} else {
toSqlTypeHelper(
Schema.createUnion(remainingUnionTypes.asJava),
existingRecordNames,
useStableIdForUnionType).copy(nullable = true)
useStableIdForUnionType,
stableIdPrefixForUnionType).copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
case Seq(t1) =>
toSqlTypeHelper(avroSchema.getTypes.get(0),
existingRecordNames, useStableIdForUnionType)
existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType)
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
Expand All @@ -179,20 +197,26 @@ object SchemaConverters {
val fieldNameSet : mutable.Set[String] = mutable.Set()
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType)
val schemaType = toSqlTypeHelper(
s,
existingRecordNames,
useStableIdForUnionType,
stableIdPrefixForUnionType)

val fieldName = if (useStableIdForUnionType) {
// Avro's field name may be case sensitive, so field names for two named type
// could be "a" and "A" and we need to distinguish them. In this case, we throw
// an exception.
val temp_name = s"member_${s.getName.toLowerCase(Locale.ROOT)}"
if (fieldNameSet.contains(temp_name)) {
// Stable id prefix can be empty so the name of the field can be just the type.
val tempFieldName =
s"${stableIdPrefixForUnionType}${s.getName.toLowerCase(Locale.ROOT)}"
if (fieldNameSet.contains(tempFieldName)) {
throw new IncompatibleSchemaException(
"Cannot generate stable indentifier for Avro union type due to name " +
s"conflict of type name ${s.getName}")
}
fieldNameSet.add(temp_name)
temp_name
fieldNameSet.add(tempFieldName)
tempFieldName
} else {
s"member$i"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ case class AvroPartitionReaderFactory(
options.positionalFieldMatching,
datetimeRebaseMode,
avroFilters,
options.useStableIdForUnionType)
options.useStableIdForUnionType,
options.stableIdPrefixForUnionType)
override val stopPosition = partitionedFile.start + partitionedFile.length

override def next(): Boolean = hasNextRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite

val expected = {
val avroSchema = new Schema.Parser().parse(schema)
SchemaConverters.toSqlType(avroSchema, false).dataType match {
SchemaConverters.toSqlType(avroSchema, false, "").dataType match {
case st: StructType => Row.fromSeq((0 until st.length).map(_ => null))
case _ => null
}
Expand Down Expand Up @@ -283,14 +283,15 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
data: GenericData.Record,
expected: Option[Any],
filters: StructFilters = new NoopFilters): Unit = {
val dataType = SchemaConverters.toSqlType(schema, false).dataType
val dataType = SchemaConverters.toSqlType(schema, false, "").dataType
val deserializer = new AvroDeserializer(
schema,
dataType,
false,
RebaseSpec(LegacyBehaviorPolicy.CORRECTED),
filters,
false)
false,
"")
val deserialized = deserializer.deserialize(data)
expected match {
case None => assert(deserialized == None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
val avroOptions = AvroOptions(options)
val avroSchema = avroOptions.schema.get
val sparkSchema = SchemaConverters
.toSqlType(avroSchema, avroOptions.useStableIdForUnionType)
.toSqlType(avroSchema, avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType)
.dataType
.asInstanceOf[StructType]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class AvroRowReaderSuite
false,
RebaseSpec(CORRECTED),
new NoopFilters,
false)
false,
"")
override val stopPosition = fileSize

override def hasNext: Boolean = hasNextRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ object AvroSerdeSuite {
isPositional(matchType),
RebaseSpec(CORRECTED),
new NoopFilters,
false)
false,
"")
}

/**
Expand Down
Loading