@@ -276,11 +276,14 @@ object AgnosticEncoders {
 * another encoder. This is a fallback for scenarios where objects can't be represented using
 * standard encoders; an example is where we use a different (opaque) serialization
 * format (e.g. Java serialization, Kryo serialization, or protobuf).
+ * @param nullable
+ *   defaults to false, indicating the codec guarantees decode/encode results are non-null
*/
case class TransformingEncoder[I, O](
clsTag: ClassTag[I],
transformed: AgnosticEncoder[O],
-   codecProvider: () => Codec[_ >: I, O])
+   codecProvider: () => Codec[_ >: I, O],
+   override val nullable: Boolean = false)
extends AgnosticEncoder[I] {
override def isPrimitive: Boolean = transformed.isPrimitive
override def dataType: DataType = transformed.dataType
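
For illustration, a minimal sketch of wiring up the new nullable flag. Point and PointCodec are hypothetical; Codec's encode/decode methods, StringEncoder, and TransformingEncoder are the ones referenced elsewhere in this PR.

import scala.reflect.classTag
import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

// Hypothetical domain type, stored as a single string column.
case class Point(x: Int, y: Int)

// Hypothetical codec translating Point to and from its string form.
class PointCodec extends Codec[Point, String] {
  override def encode(in: Point): String = s"${in.x},${in.y}"
  override def decode(out: String): Point = {
    val Array(x, y) = out.split(",")
    Point(x.toInt, y.toInt)
  }
}

// nullable = true opts out of the non-null guarantee, so the planner
// will null-check the codec's decode/encode results.
val pointEncoder = TransformingEncoder(
  classTag[Point],
  StringEncoder,
  () => new PointCodec,
  nullable = true)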
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
- import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
@@ -270,6 +270,8 @@ object DeserializerBuildHelper {
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
+ case ae: AgnosticExpressionPathEncoder[_] =>
+   ae.fromCatalyst(path)
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
@@ -447,13 +449,13 @@ object DeserializerBuildHelper {
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)

- case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+ case TransformingEncoder(tag, _, codec, _) if codec == JavaSerializationCodec =>
    DecodeUsingSerializer(path, tag, kryo = false)

- case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+ case TransformingEncoder(tag, _, codec, _) if codec == KryoSerializationCodec =>
    DecodeUsingSerializer(path, tag, kryo = true)

- case TransformingEncoder(tag, encoder, provider) =>
+ case TransformingEncoder(tag, encoder, provider, _) =>
Invoke(
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
"decode",
@@ -21,7 +21,7 @@ import scala.language.existentials

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
- import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -306,6 +306,7 @@ object SerializerBuildHelper {
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
+ case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
case BoxedByteEncoder => createSerializerForByte(input)
@@ -418,18 +419,21 @@
}
createSerializerForObject(input, serializedFields)

- case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+ case TransformingEncoder(_, _, codec, _) if codec == JavaSerializationCodec =>
    EncodeUsingSerializer(input, kryo = false)

- case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+ case TransformingEncoder(_, _, codec, _) if codec == KryoSerializationCodec =>
    EncodeUsingSerializer(input, kryo = true)

- case TransformingEncoder(_, encoder, codecProvider) =>
+ case TransformingEncoder(_, encoder, codecProvider, _) =>
    val encoded = Invoke(
      Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
      "encode",
      externalDataTypeFor(encoder),
-     input :: Nil)
+     input :: Nil,
+     propagateNull = input.nullable,
+     returnNullable = input.nullable
+   )
createSerializer(encoder, encoded)
}
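
A hedged usage sketch: assuming ExpressionEncoder.apply accepts an AgnosticEncoder (as in recent Spark), and reusing the hypothetical pointEncoder from the earlier sketch, the propagated nullability can be exercised end to end.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Derive serializer/deserializer expressions from the agnostic encoder.
val exprEnc: ExpressionEncoder[Point] = ExpressionEncoder(pointEncoder)

// Round-trip a value through the Catalyst representation.
val toRow = exprEnc.createSerializer()
val fromRow = exprEnc.resolveAndBind().createDeserializer()
assert(fromRow(toRow(Point(1, 2))) == Point(1, 2))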

@@ -486,6 +490,7 @@ object SerializerBuildHelper {
nullable: Boolean): Expression => Expression = { input =>
val expected = enc match {
case OptionEncoder(_) => lenientExternalDataTypeFor(enc)
+ case TransformingEncoder(_, transformed, _, _) => lenientExternalDataTypeFor(transformed)
case _ => enc.dataType
}

@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders

import scala.collection.Map

+ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -26,6 +27,30 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

+ /**
+  * :: DeveloperApi ::
+  * Extensible [[AgnosticEncoder]] providing conversion extension points over type T.
+  * @tparam T the external type handled by this encoder
+  */
+ @DeveloperApi
+ @deprecated("This trait is intended only as a migration tool and will be removed in 4.1")
+ trait AgnosticExpressionPathEncoder[T]
+   extends AgnosticEncoder[T] {
+   /**
+    * Converts from T to InternalRow.
+    * @param input the starting input path
+    * @return the expression that serializes `input` to the Catalyst representation
+    */
+   def toCatalyst(input: Expression): Expression
+
+   /**
+    * Converts from InternalRow to T.
+    * @param inputPath path expression from InternalRow
+    * @return the expression that deserializes `inputPath` to an instance of T
+    */
+   def fromCatalyst(inputPath: Expression): Expression
+ }
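
A minimal sketch of an implementation of the trait above: a hypothetical UuidEncoder that stores java.util.UUID values in a StringType column. It assumes Spark's Invoke and StaticInvoke object-expression helpers (only Invoke appears in this PR's diffs).

import scala.reflect.{classTag, ClassTag}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.types.{DataType, ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical encoder mapping java.util.UUID to a StringType column.
case class UuidEncoder() extends AgnosticExpressionPathEncoder[java.util.UUID] {
  override def isPrimitive: Boolean = false
  override def dataType: DataType = StringType
  override def clsTag: ClassTag[java.util.UUID] = classTag[java.util.UUID]

  // Serialize: call input.toString, then wrap the result via UTF8String.fromString.
  override def toCatalyst(input: Expression): Expression =
    StaticInvoke(
      classOf[UTF8String],
      StringType,
      "fromString",
      Invoke(input, "toString", ObjectType(classOf[String])) :: Nil)

  // Deserialize: convert the Catalyst UTF8String back to java.lang.String,
  // then call the static UUID.fromString.
  override def fromCatalyst(inputPath: Expression): Expression =
    StaticInvoke(
      classOf[java.util.UUID],
      ObjectType(classOf[java.util.UUID]),
      "fromString",
      Invoke(inputPath, "toString", ObjectType(classOf[String])) :: Nil)
}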

/**
 * Helper class for generating [[ExpressionEncoder]]s.
*/
@@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{OptionEncoder, TransformingEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -215,6 +216,13 @@ case class ExpressionEncoder[T](
StructField(s.name, s.dataType, s.nullable)
})

+ private def transformerOfOption(enc: AgnosticEncoder[_]): Boolean =
+   enc match {
+     case t: TransformingEncoder[_, _] => transformerOfOption(t.transformed)
+     case _: OptionEncoder[_] => true
+     case _ => false
+   }
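
To see what the recursion guards against, a sketch with a hypothetical identityCodec of type Codec[Option[String], Option[String]]: an encoder whose transformed target is ultimately an Option, possibly through several TransformingEncoder layers.

val inner = TransformingEncoder(
  classTag[Option[String]],
  OptionEncoder(StringEncoder),
  () => identityCodec)
val outer = TransformingEncoder(classTag[Option[String]], inner, () => identityCodec)
// transformerOfOption(outer) == true: the Option core is found through both
// layers, so outer is excluded from the top-level struct check below.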

/**
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
@@ -228,7 +236,8 @@
 * returns true if `T` is serialized as a struct and is not an `Option` type.
*/
def isSerializedAsStructForTopLevel: Boolean = {
- isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+ isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) &&
+   !transformerOfOption(encoder)
  }

Review thread on this change:

Contributor: isSerializedAsStruct && !transformerOfOption(encoder)?

Contributor: I am wondering if we should make these checks part of the AgnosticEncoder API.

Author: I think it'd make sense; for the path encoder backwards-compat logic I can embed / document that in the shim. The Builders could embed that. I can take a stab at that post rc2.

// serializer expressions are used to encode an object to a row, while the object is usually an