diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 55613b2b2013..ce2bb243b38c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
@@ -270,6 +270,8 @@
       enc: AgnosticEncoder[_],
       path: Expression,
       walkedTypePath: WalkedTypePath): Expression = enc match {
+    case ae: AgnosticExpressionPathEncoder[_] =>
+      ae.fromCatalyst(path)
     case _ if isNativeEncoder(enc) => path
     case _: BoxedLeafEncoder[_, _] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 089d463ecacb..108ade1fec74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -306,6 +306,7 @@ object SerializerBuildHelper {
    * by encoder `enc`.
    */
   private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
+    case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
     case _ if isNativeEncoder(enc) => input
     case BoxedBooleanEncoder => createSerializerForBoolean(input)
     case BoxedByteEncoder => createSerializerForByte(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 81743251bada..0eaf9361e02a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders
 
 import scala.collection.Map
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -26,6 +27,29 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
+/**
+ * :: DeveloperApi ::
+ * Extensible [[AgnosticEncoder]] providing conversion extension points over type T.
+ * @tparam T the external type handled by this encoder
+ */
+@DeveloperApi
+trait AgnosticExpressionPathEncoder[T]
+  extends AgnosticEncoder[T] {
+  /**
+   * Converts an expression over T into an expression over the Catalyst representation.
+   * @param input expression producing the external value to serialize
+   * @return the serialization expression
+   */
+  def toCatalyst(input: Expression): Expression
+
+  /**
+   * Converts a Catalyst path expression back into an expression over T.
+   * @param inputPath path expression into the InternalRow
+   * @return the deserialization expression
+   */
+  def fromCatalyst(inputPath: Expression): Expression
+}
+
 /**
  * Helper class for Generating [[ExpressionEncoder]]s.
 */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8963c9de4ee4..272ad4783d15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -35,9 +35,11 @@ import org.apache.spark.TestUtils.withListener
 import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, ExpressionEncoder, OuterScopes}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BoxedIntEncoder
-import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper.createDeserializerForString
+import org.apache.spark.sql.catalyst.SerializerBuildHelper.createSerializerForString
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, AgnosticExpressionPathEncoder, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, ProductEncoder}
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expression, GenericRowWithSchema}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -2803,6 +2805,21 @@
     }
   }
 
+  test("SPARK-49960: joinWith custom encoder") {
+    /*
+    Based on "joinWith class with primitive, toDF", but with a "custom" encoder.
+    Removing the AgnosticExpressionPathEncoder cases from SerializerBuildHelper
+    and DeserializerBuildHelper triggers MatchErrors here.
+    */
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = SparkSession.active.createDataset[ClassData](Seq(ClassData("a", 1),
+      ClassData("b", 2)))(CustomPathEncoder.custClassDataEnc)
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+      Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+  }
+
   test("SPARK-49961: transform type should be consistent (classic)") {
     val ds = Seq(1, 2).toDS()
     val f: classic.Dataset[Int] => classic.Dataset[Int] =
@@ -2828,6 +2845,38 @@
   }
 }
 
+/**
+ * SPARK-49960 - Mimics a custom encoder such as those provided by Typelevel Frameless.
+ */
+object CustomPathEncoder {
+
+  val realClassDataEnc: ProductEncoder[ClassData] =
+    Encoders.product[ClassData].asInstanceOf[ProductEncoder[ClassData]]
+
+  val custStringEnc: AgnosticExpressionPathEncoder[String] =
+    new AgnosticExpressionPathEncoder[String] {
+
+      override def toCatalyst(input: Expression): Expression =
+        createSerializerForString(input)
+
+      override def fromCatalyst(inputPath: Expression): Expression =
+        createDeserializerForString(inputPath, returnNullable = false)
+
+      override def isPrimitive: Boolean = false
+
+      override def dataType: DataType = StringType
+
+      override def clsTag: ClassTag[String] = implicitly[ClassTag[String]]
+
+      override def isStruct: Boolean = true
+    }
+
+  val custClassDataEnc: ProductEncoder[ClassData] = realClassDataEnc.copy(fields =
+    Seq(realClassDataEnc.fields.head.copy(enc = custStringEnc),
+      realClassDataEnc.fields.last)
+  )
+}
+
 class DatasetLargeResultCollectingSuite extends QueryTest with SharedSparkSession {
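
For context, a minimal sketch of how an external library might use the new extension point, beyond the CustomPathEncoder test fixture above. This is not part of the patch: Name, NameAsStringEncoder, and the choice to store the wrapper as a bare string column are hypothetical stand-ins for what a library like Frameless would generate; createSerializerForString, createDeserializerForString, Invoke, and NewInstance are existing Catalyst building blocks.

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.DeserializerBuildHelper.createDeserializerForString
import org.apache.spark.sql.catalyst.SerializerBuildHelper.createSerializerForString
import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.types.{DataType, ObjectType, StringType}

// Hypothetical wrapper type that should be encoded as a plain string column
// rather than as a one-field struct.
case class Name(value: String)

object NameAsStringEncoder extends AgnosticExpressionPathEncoder[Name] {

  // Serializer: extract `value` from the Name object, then reuse Spark's
  // built-in String-to-UTF8String serializer expression.
  override def toCatalyst(input: Expression): Expression =
    createSerializerForString(Invoke(input, "value", ObjectType(classOf[String])))

  // Deserializer: convert the Catalyst string back into a java.lang.String,
  // then wrap it in a new Name instance.
  override def fromCatalyst(inputPath: Expression): Expression =
    NewInstance(
      classOf[Name],
      createDeserializerForString(inputPath, returnNullable = false) :: Nil,
      ObjectType(classOf[Name]))

  override def isPrimitive: Boolean = false

  override def dataType: DataType = StringType

  override def clsTag: ClassTag[Name] = implicitly[ClassTag[Name]]

  // Unlike the test's struct-field usage, a bare string column is not a struct.
  override def isStruct: Boolean = false
}

// Usage: the encoder is a regular Encoder[Name], so it can be passed wherever
// Spark expects one, e.g.
//   SparkSession.active.createDataset(Seq(Name("a"), Name("b")))(NameAsStringEncoder)

Without the new match arms in SerializerBuildHelper.createSerializer and DeserializerBuildHelper.createDeserializer, such an encoder falls through every built-in case and fails with a MatchError, which is exactly what the SPARK-49960 test above guards against.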