From d590dd0efd0372df16c6913806e7b39d1174b2c0 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Mon, 19 Aug 2024 16:08:52 -0400
Subject: [PATCH 1/3] Support transforming encoders

---
 .../scala/org/apache/spark/sql/Encoders.scala | 27 +++++++++++-
 .../client/arrow/ArrowEncoderSuite.scala      | 35 +++++++++++++---
 .../catalyst/encoders/AgnosticEncoder.scala   | 17 +++++++-
 .../spark/sql/catalyst/encoders/codecs.scala  | 42 +++++++++++++++++++
 .../catalyst/DeserializerBuildHelper.scala    | 13 ++++--
 .../sql/catalyst/SerializerBuildHelper.scala  | 14 +++++--
 .../encoders/ExpressionEncoderSuite.scala     | 14 +++++++
 .../client/arrow/ArrowDeserializer.scala      |  7 ++++
 .../client/arrow/ArrowSerializer.scala        | 10 ++++-
 9 files changed, 165 insertions(+), 14 deletions(-)
 create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
index 74f013380313..39270587f3df 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -16,10 +16,11 @@
  */
 package org.apache.spark.sql
 
+import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, JavaSerializationCodec, RowEncoder => RowEncoderFactory}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.types.StructType
 
@@ -176,6 +177,30 @@ object Encoders {
    */
  def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema)
 
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
+   * serialization. This encoder maps T into a single byte array (binary) field.
+   *
+   * T must be publicly accessible.
+   *
+   * @note This is extremely inefficient and should only be used as the last resort.
+   * @since 1.6.0
+   */
+  def javaSerialization[T: ClassTag]: Encoder[T] = {
+    TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, JavaSerializationCodec)
+  }
+
+  /**
+   * Creates an encoder that serializes objects of type T using generic Java serialization.
+   * This encoder maps T into a single byte array (binary) field.
+   *
+   * T must be publicly accessible.
+   *
+   * @note This is extremely inefficient and should only be used as the last resort.
+   * @since 1.6.0
+   */
+  def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+
   private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
     ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 709e2cf0e84e..f827a5c11658 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
 import org.apache.arrow.vector.VarBinaryVector
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.{sql, SparkUnsupportedOperationException}
+import org.apache.spark.sql.{AnalysisException, Encoders, Row}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
 import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
 import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.test.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
 /**
  * Tests for encoding external data to and from arrow.
  */
@@ -769,6 +769,26 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("java serialization") {
+    val encoder = sql.encoderFor(Encoders.javaSerialization[(Int, String)])
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(i => (i, "itr_" + i))
+    }
+  }
+
+  test("transforming encoder") {
+    val schema = new StructType()
+      .add("key", IntegerType)
+      .add("value", StringType)
+    val encoder = TransformingEncoder(
+      classTag[(Int, String)],
+      toRowEncoder(schema),
+      () => new TestCodec)
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(i => (i, "v" + i))
+    }
+  }
+
   /* ******************************************************************** *
    * Arrow deserialization upcasting
    * ******************************************************************** */
@@ -1136,3 +1156,8 @@ class UDTNotSupported extends UserDefinedType[UDTNotSupportedClass] {
     case i: Int => UDTNotSupportedClass(i)
   }
 }
+
+class TestCodec extends Codec[(Int, String), Row] {
+  override def encode(in: (Int, String)): Row = Row(in._1, in._2)
+  override def decode(out: Row): (Int, String) = (out.getInt(0), out.getString(1))
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 9133abce88ad..639b23f71414 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -247,5 +247,20 @@ object AgnosticEncoders {
     ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
   val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
     JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
 
-}
+  /**
+   * Encoder that transforms external data into a representation that can be further processed
+   * by another encoder. This is a fallback for scenarios where objects cannot be represented
+   * using standard encoders, for example when a different (opaque) serialization format is used
+   * (e.g. Java serialization, Kryo serialization, or protobuf).
+   */
+  case class TransformingEncoder[I, O](
+      clsTag: ClassTag[I],
+      transformed: AgnosticEncoder[O],
+      codecProvider: () => Codec[_ >: I, O]) extends AgnosticEncoder[I] {
+    override def isPrimitive: Boolean = transformed.isPrimitive
+    override def dataType: DataType = transformed.dataType
+    override def schema: StructType = transformed.schema
+    override def isStruct: Boolean = transformed.isStruct
+  }
+}
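A minimal sketch of what the contract above enables (illustrative only, not part of the patch): a custom Codec adapts an external type to one an existing encoder already understands. Here a hypothetical UuidCodec stores java.util.UUID as String so the existing StringEncoder does the actual Catalyst/Arrow work; the Codec trait itself is introduced in the next file.

    import java.util.UUID
    import scala.reflect.classTag

    import org.apache.spark.sql.catalyst.encoders.Codec
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

    object UuidEncoderExample {
      // Illustrative codec: store a UUID in its string form so that the
      // existing StringEncoder can handle the storage format.
      class UuidCodec extends Codec[UUID, String] {
        override def encode(in: UUID): String = in.toString
        override def decode(out: String): UUID = UUID.fromString(out)
      }

      // AgnosticEncoder[String] + Codec[UUID, String] => AgnosticEncoder[UUID].
      val uuidEncoder: TransformingEncoder[UUID, String] =
        TransformingEncoder(classTag[UUID], StringEncoder, () => new UuidCodec)
    }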
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
new file mode 100644
index 000000000000..46862ebbccdf
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.util.SparkSerDeUtils
+
+/**
+ * Codec for doing conversions between two representations.
+ *
+ * @tparam I input type (typically the external representation of the data).
+ * @tparam O output type (typically the internal representation of the data).
+ */
+trait Codec[I, O] {
+  def encode(in: I): O
+  def decode(out: O): I
+}
+
+/**
+ * A codec that uses Java Serialization as its output format.
+ */
+class JavaSerializationCodec[I] extends Codec[I, Array[Byte]] {
+  override def encode(in: I): Array[Byte] = SparkSerDeUtils.serialize(in)
+  override def decode(out: Array[Byte]): I = SparkSerDeUtils.deserialize(out)
+}
+
+object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
+  override def apply(): Codec[Any, Array[Byte]] = new JavaSerializationCodec[Any]
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 0b88d5a4130e..40b49506b58a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._
@@ -410,6 +410,13 @@ object DeserializerBuildHelper {
       val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
       val result = InitializeJavaBean(newInstance, setters.toMap)
       exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
+
+    case TransformingEncoder(tag, encoder, provider) =>
+      Invoke(
+        Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
+        "decode",
+        ObjectType(tag.runtimeClass),
+        createDeserializer(encoder, path, walkedTypePath) :: Nil)
   }
 
   private def deserializeArray(
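A plain-Scala sketch of the semantics being wired up here (illustrative names, not the actual Catalyst machinery): the deserializer case above decodes after the inner deserializer runs, and the serializer counterpart in the next file encodes before delegating. The real implementation builds Invoke expressions on a Literal-wrapped codec instance.

    import org.apache.spark.sql.catalyst.encoders.Codec

    object CodecCompositionSketch {
      // Serialization: run the codec first, then the serializer for the
      // transformed representation (mirrors the SerializerBuildHelper case).
      def serializeVia[I, O](codec: Codec[I, O], serializeO: O => Any): I => Any =
        value => serializeO(codec.encode(value))

      // Deserialization: run the inner deserializer first, then decode back
      // to the external type (mirrors the DeserializerBuildHelper case above).
      def deserializeVia[I, O](codec: Codec[I, O], deserializeO: Any => O): Any => I =
        row => codec.decode(deserializeO(row))
    }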
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index cd087514f4be..38bf0651d6f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,10 +21,10 @@ import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
@@ -397,6 +397,14 @@ object SerializerBuildHelper {
         f.name -> createSerializer(f.enc, fieldValue)
       }
       createSerializerForObject(input, serializedFields)
+
+    case TransformingEncoder(_, encoder, codecProvider) =>
+      val encoded = Invoke(
+        Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
+        "encode",
+        externalDataTypeFor(encoder),
+        input :: Nil)
+      createSerializer(encoder, encoded)
   }
 
   private def serializerForArray(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f46c02326e8b..7b8d8be6bbee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import java.util.Arrays
 
 import scala.collection.mutable.ArrayBuffer
+import scala.reflect.classTag
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException}
@@ -29,6 +30,7 @@ import org.apache.spark.sql.{Encoder, Encoders, Row}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NaNvl}
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -550,6 +552,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
   encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
   encodeDecodeTest(FooEnum.E1, "scala Enum")
 
+  test("transforming encoder") {
+    val encoder = ExpressionEncoder(TransformingEncoder(
+      classTag[(Long, Long)],
+      BinaryEncoder,
+      JavaSerializationCodec))
+      .resolveAndBind()
+    assert(encoder.schema == new StructType().add("value", BinaryType))
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
+    assert(fromRow(toRow((11, 14))) == (11, 14))
+  }
+
   // Scala / Java big decimals ----------------------------------------------------------
 
   encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 17d8444574f6..f3abaddb0110 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -359,6 +359,13 @@ object ArrowDeserializers {
         }
       }
 
+    case (TransformingEncoder(_, encoder, provider), v) =>
+      new Deserializer[Any] {
+        private[this] val codec = provider()
+        private[this] val deserializer = deserializerFor(encoder, v, timeZoneId)
+        override def get(i: Int): Any = codec.decode(deserializer.get(i))
+      }
+
     case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
       throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4b7b39423545..f8a5c63ac3ab 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -35,7 +35,7 @@ import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -442,6 +442,14 @@ object ArrowSerializer {
         o => getter.invoke(o)
       }
 
+    case (TransformingEncoder(_, encoder, provider), v) =>
+      new Serializer {
+        private[this] val codec = provider().asInstanceOf[Codec[Any, Any]]
+        private[this] val delegate: Serializer = serializerFor(encoder, v)
+        override def write(index: Int, value: Any): Unit =
+          delegate.write(index, codec.encode(value))
+      }
+
     case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
       throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)

From c9c5f0062fe443b32a200a1c30cc977385f6b507 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Mon, 19 Aug 2024 16:09:46 -0400
Subject: [PATCH 2/3] Style

---
 .../src/main/scala/org/apache/spark/sql/Encoders.scala   | 10 ++++++----
 .../sql/connect/client/arrow/ArrowEncoderSuite.scala     |  6 ++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
index 39270587f3df..5a1e8e84f0d5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -183,7 +183,8 @@ object Encoders {
    *
    * T must be publicly accessible.
    *
-   * @note This is extremely inefficient and should only be used as the last resort.
+   * @note
+   *   This is extremely inefficient and should only be used as the last resort.
    * @since 1.6.0
    */
   def javaSerialization[T: ClassTag]: Encoder[T] = {
@@ -191,12 +192,13 @@
   }
 
   /**
-   * Creates an encoder that serializes objects of type T using generic Java serialization.
-   * This encoder maps T into a single byte array (binary) field.
+   * Creates an encoder that serializes objects of type T using generic Java serialization. This
+   * encoder maps T into a single byte array (binary) field.
    *
    * T must be publicly accessible.
    *
-   * @note This is extremely inefficient and should only be used as the last resort.
+   * @note
+   *   This is extremely inefficient and should only be used as the last resort.
    * @since 1.6.0
    */
   def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index f827a5c11658..70b471cf74b3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -780,10 +780,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     val schema = new StructType()
       .add("key", IntegerType)
       .add("value", StringType)
-    val encoder = TransformingEncoder(
-      classTag[(Int, String)],
-      toRowEncoder(schema),
-      () => new TestCodec)
+    val encoder =
+      TransformingEncoder(classTag[(Int, String)], toRowEncoder(schema), () => new TestCodec)
     roundTripAndCheckIdentical(encoder) { () =>
       Iterator.tabulate(10)(i => (i, "v" + i))
     }

From bb833a7e5f835bafa6486182f522284118175408 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 20 Aug 2024 00:27:49 -0400
Subject: [PATCH 3/3] Version

---
 .../jvm/src/main/scala/org/apache/spark/sql/Encoders.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
index 5a1e8e84f0d5..ffd997577006 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -185,7 +185,7 @@ object Encoders {
    *
    * @note
    *   This is extremely inefficient and should only be used as the last resort.
-   * @since 1.6.0
+   * @since 4.0.0
    */
   def javaSerialization[T: ClassTag]: Encoder[T] = {
     TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, JavaSerializationCodec)
   }
@@ -199,7 +199,7 @@ object Encoders {
    *
    * @note
    *   This is extremely inefficient and should only be used as the last resort.
-   * @since 1.6.0
+   * @since 4.0.0
    */
   def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
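
A usage sketch for the series as a whole, assuming all three patches are applied: it builds the same encoder that Encoders.javaSerialization now produces and round-trips a value through Catalyst, mirroring the "transforming encoder" test added to ExpressionEncoderSuite. The object name JavaSerializationDemo is illustrative only.

    import scala.reflect.classTag

    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, JavaSerializationCodec}
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}

    object JavaSerializationDemo {
      def main(args: Array[String]): Unit = {
        // Any T is pushed through the Java-serialization codec and stored in a
        // single binary column; this is what Encoders.javaSerialization builds.
        val agnostic = TransformingEncoder(
          classTag[(Long, Long)],
          BinaryEncoder,
          JavaSerializationCodec)

        // Resolve against the single-binary-column schema and round-trip a value.
        val encoder = ExpressionEncoder(agnostic).resolveAndBind()
        val toRow = encoder.createSerializer()
        val fromRow = encoder.createDeserializer()
        assert(fromRow(toRow((11L, 14L))) == ((11L, 14L)))
      }
    }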