diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala index 7c09ce5b7a781..868056bd3cdd9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala @@ -90,8 +90,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister if (requiredSchema.isEmpty) { filteredResult.map(_ => emptyUnsafeRow) } else { - val converter = RowEncoder(requiredSchema) - filteredResult.map(row => converter.toRow(row)) + val toRow = RowEncoder(requiredSchema).createSerializer() + filteredResult.map(row => toRow(row)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 6ead4df87fb54..da8f3a24ff27e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat LabeledPoint(label, Vectors.sparse(numFeatures, indices, values)) } - val converter = RowEncoder(dataSchema) + val toRow = RowEncoder(dataSchema).createSerializer() val fullOutput = dataSchema.map { f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() } @@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat points.map { pt => val features = if (isSparse) pt.features.toSparse else pt.features.toDense - requiredColumns(converter.toRow(Row(pt.label, features))) + requiredColumns(toRow(Row(pt.label, features))) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 5f19e466ecad0..3caa8f6d5b1e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -38,12 +38,14 @@ object UDTSerializationBenchmark extends BenchmarkBase { val iters = 1e2.toInt val numRows = 1e3.toInt - val encoder = ExpressionEncoder[Vector].resolveAndBind() + val encoder = ExpressionEncoder[Vector]().resolveAndBind() + val toRow = encoder.createSerializer() + val fromRow = encoder.createDeserializer() val vectors = (1 to numRows).map { i => Vectors.dense(Array.fill(1e5.toInt)(1.0 * i)) }.toArray - val rows = vectors.map(encoder.toRow) + val rows = vectors.map(toRow) val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output) @@ -51,7 +53,7 @@ object UDTSerializationBenchmark extends BenchmarkBase { var sum = 0 var i = 0 while (i < numRows) { - sum += encoder.toRow(vectors(i)).numFields + sum += toRow(vectors(i)).numFields i += 1 } } @@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase { var sum = 0 var i = 0 while (i < numRows) { - sum += encoder.fromRow(rows(i)).numActives + sum += fromRow(rows(i)).numActives i += 1 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index c43a86ad48ec9..ea760d80541c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -58,8 +58,7 @@ import org.apache.spark.sql.types._ * }}} * * == Implementation == - * - 
Encoders are not required to be thread-safe and thus they do not need to use locks to guard - * against concurrent access if they reuse internal buffers to improve performance. + * - Encoders should be thread-safe. * * @since 1.6.0 */ @@ -76,10 +75,4 @@ trait Encoder[T] extends Serializable { * A ClassTag that can be used to construct an Array to contain a collection of `T`. */ def clsTag: ClassTag[T] - - /** - * Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed - * up the [[Encoder]] creation. - */ - def makeCopy: Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b820cb1a5c522..f08416fcaba8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.encoders +import java.io.ObjectInputStream + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} @@ -162,6 +165,56 @@ object ExpressionEncoder { e4: ExpressionEncoder[T4], e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + + private val anyObjectType = ObjectType(classOf[Any]) + + /** + * Function that deserializes an [[InternalRow]] into an object of type `T`. This class is not + * thread-safe. + */ + class Deserializer[T](private val expressions: Seq[Expression]) + extends (InternalRow => T) with Serializable { + @transient + private[this] var constructProjection: Projection = _ + + override def apply(row: InternalRow): T = try { + if (constructProjection == null) { + constructProjection = SafeProjection.create(expressions) + } + constructProjection(row).get(0, anyObjectType).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Error while decoding: $e\n" + + s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e) + } + } + + /** + * Function that serializes an object of type `T` to an [[InternalRow]]. This class is not + * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow]] + * object. Thus, the caller should copy the result before making another call if required.
+ */ + class Serializer[T](private val expressions: Seq[Expression]) + extends (T => InternalRow) with Serializable { + @transient + private[this] var inputRow: GenericInternalRow = _ + + @transient + private[this] var extractProjection: UnsafeProjection = _ + + override def apply(t: T): InternalRow = try { + if (extractProjection == null) { + inputRow = new GenericInternalRow(1) + extractProjection = GenerateUnsafeProjection.generate(expressions) + } + inputRow(0) = t + extractProjection(inputRow) + } catch { + case e: Exception => + throw new RuntimeException(s"Error while encoding: $e\n" + + s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e) + } + } } /** @@ -302,25 +355,22 @@ case class ExpressionEncoder[T]( } @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate({ + private lazy val optimizedDeserializer: Seq[Expression] = { // When using `ExpressionEncoder` directly, we will skip the normal query processing steps // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's // important to codegen performance. - val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer)) + val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer))) optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs - }) - - @transient - private lazy val inputRow = new GenericInternalRow(1) + } @transient - private lazy val constructProjection = SafeProjection.create({ + private lazy val optimizedSerializer = { // When using `ExpressionEncoder` directly, we will skip the normal query processing steps // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's // important to codegen performance. - val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer))) + val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer)) optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs - }) + } /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form @@ -332,31 +382,21 @@ case class ExpressionEncoder[T]( } /** - * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to - * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should - * copy the result before making another call if required. + * Create a serializer that can convert an object of type `T` to a Spark SQL Row. + * + * Note that the returned [[Serializer]] is not thread safe. Multiple calls to + * `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus, + * the caller should copy the result before making another call if required. */ - def toRow(t: T): InternalRow = try { - inputRow(0) = t - extractProjection(inputRow) - } catch { - case e: Exception => - throw new RuntimeException(s"Error while encoding: $e\n" + - s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e) - } + def createSerializer(): Serializer[T] = new Serializer[T](optimizedSerializer) /** - * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `resolveAndBind` an encoder to a specific schema before you can call this - * function. + * Create a deserializer that can convert a Spark SQL Row into an object of type `T`. + * + * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a + * deserializer. 
*/ - def fromRow(row: InternalRow): T = try { - constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] - } catch { - case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n" + - s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e) - } + def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer) /** * The process of resolution to a given schema throws away information about where a given field @@ -383,8 +423,6 @@ case class ExpressionEncoder[T]( .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") override def toString: String = s"class[$schemaString]" - - override def makeCopy: ExpressionEncoder[T] = copy() } // A dummy logical plan that can hold expressions and go through optimizer rules. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 1ac7ca676a876..e80f03ea84756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -110,8 +110,8 @@ case class ScalaUDF( } else { val encoder = inputEncoders(i) if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) { - val enc = encoder.get.resolveAndBind() - row: Any => enc.fromRow(row.asInstanceOf[InternalRow]) + val fromRow = encoder.get.resolveAndBind().createDeserializer() + row: Any => fromRow(row.asInstanceOf[InternalRow]) } else { CatalystTypeConverters.createToScalaConverter(dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 3b4b80daf0843..3f0121bcf4a63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase { def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = { runBenchmark(name) { val generator = RandomDataGenerator.forType(schema, nullable = false).get - val encoder = RowEncoder(schema) + val toRow = RowEncoder(schema).createSerializer() val attrs = schema.toAttributes val safeProjection = GenerateSafeProjection.generate(attrs, attrs) val rows = (1 to numRows).map(_ => // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format. 
- safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() + safeProjection(toRow(generator().asInstanceOf[Row])).copy() ).toArray val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 42a4cfc91f826..950e313fb727a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { val generator = RandomDataGenerator.forType(schema, nullable = false).get - val encoder = RowEncoder(schema) - (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray + val toRow = RowEncoder(schema).createSerializer() + (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray } override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 53cb8bce0a52d..48f4ef5051fb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -42,22 +43,29 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass]) class EncoderResolutionSuite extends PlanTest { private val str = UTF8String.fromString("hello") + def testFromRow[T]( + encoder: ExpressionEncoder[T], + attributes: Seq[Attribute], + row: InternalRow): Unit = { + encoder.resolveAndBind(attributes).createDeserializer().apply(row) + } + test("real type doesn't match encoder schema but they are compatible: product") { val encoder = ExpressionEncoder[StringLongClass] // int type can be up cast to long type val attrs1 = Seq('a.string, 'b.int) - encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1)) + testFromRow(encoder, attrs1, InternalRow(str, 1)) // int type can be up cast to string type val attrs2 = Seq('a.int, 'b.long) - encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L)) + testFromRow(encoder, attrs2, InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) + testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { @@ -65,14 +73,14 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - 
encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2)) } test("real type doesn't match encoder schema but they are compatible: primitive array") { val encoder = ExpressionEncoder[PrimitiveArrayClass] val attrs = Seq('arr.array(IntegerType)) val array = new GenericArrayData(Array(1, 2, 3)) - encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + testFromRow(encoder, attrs, InternalRow(array)) } test("the real type is not compatible with encoder schema: primitive array") { @@ -93,7 +101,7 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[ArrayClass] val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) - encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + testFromRow(encoder, attrs, InternalRow(array)) } test("real type doesn't match encoder schema but they are compatible: nested array") { @@ -103,7 +111,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq('nestedArr.array(et)) val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) - encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr)) + testFromRow(encoder, attrs, InternalRow(outerArr)) } test("the real type is not compatible with encoder schema: non-array field") { @@ -142,14 +150,14 @@ class EncoderResolutionSuite extends PlanTest { val attrs = 'a.array(IntegerType) :: Nil // It should pass analysis - val bound = encoder.resolveAndBind(attrs) + val fromRow = encoder.resolveAndBind(attrs).createDeserializer() // If no null values appear, it should work fine - bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) + fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) // If there is null value, it should throw runtime exception val e = intercept[RuntimeException] { - bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + fromRow(InternalRow(new GenericArrayData(Array(1, null)))) } assert(e.getMessage.contains("Null value appeared in non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 1036dc725c205..6a094d4aaddae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -369,14 +369,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes } test("null check for map key: String") { - val encoder = ExpressionEncoder[Map[String, Int]]() - val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) + val toRow = ExpressionEncoder[Map[String, Int]]().createSerializer() + val e = intercept[RuntimeException](toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } test("null check for map key: Integer") { - val encoder = ExpressionEncoder[Map[Integer, String]]() - val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + val toRow = ExpressionEncoder[Map[Integer, String]]().createSerializer() + val e = intercept[RuntimeException](toRow(Map((1, "a"), (null, "b")))) assert(e.getMessage.contains("Cannot use null as map key")) } @@ -436,10 +436,6 
@@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int") testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int") - encodeDecodeTest("foo" -> 1L, "makeCopy") { - Encoders.product[(String, Long)].makeCopy.asInstanceOf[ExpressionEncoder[(String, Long)]] - } - private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = { Seq(true, false).foreach { ansiEnabled => testAndVerifyNotLeakingReflectionObjects( @@ -450,12 +446,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes // Need to construct Encoder here rather than implicitly resolving it // so that SQLConf changes are respected. val encoder = ExpressionEncoder[T]() + val toRow = encoder.createSerializer() if (!ansiEnabled) { - val convertedBack = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric)) + val fromRow = encoder.resolveAndBind().createDeserializer() + val convertedBack = fromRow(toRow(bigNumeric)) assert(convertedBack === null) } else { val e = intercept[RuntimeException] { - encoder.toRow(bigNumeric) + toRow(bigNumeric) } assert(e.getMessage.contains("Error while encoding")) assert(e.getCause.getClass === classOf[ArithmeticException]) @@ -474,10 +472,10 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes // Make sure encoder is serializable. ClosureCleaner.clean((s: String) => encoder.getClass.getName) - val row = encoder.toRow(input) + val row = encoder.createSerializer().apply(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolveAndBind() - val convertedBack = try boundEncoder.fromRow(row) catch { + val convertedBack = try boundEncoder.createDeserializer().apply(row) catch { case e: Exception => fail( s"""Exception thrown while decoding diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 1a1cab823d4f3..c1158e001a780 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.util.Random import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.internal.SQLConf @@ -81,6 +82,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + private def toRow(encoder: ExpressionEncoder[Row], row: Row): InternalRow = { + encoder.createSerializer().apply(row) + } + + private def fromRow(encoder: ExpressionEncoder[Row], row: InternalRow): Row = { + encoder.createDeserializer().apply(row) + } + + private def roundTrip(encoder: ExpressionEncoder[Row], row: Row): Row = { + fromRow(encoder, toRow(encoder, row)) + } + encodeDecodeTest( new StructType() .add("null", NullType) @@ -144,8 +157,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val catalystDecimal = Decimal("1234.5678") val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal) - val row = 
encoder.toRow(input) - val convertedBack = encoder.fromRow(row) + val convertedBack = roundTrip(encoder, input) // Decimal will be converted back to Java BigDecimal when decoding. assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0) assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0) @@ -157,7 +169,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val encoder = RowEncoder(schema).resolveAndBind() val decimal = Decimal("67123.45") val input = Row(decimal) - val row = encoder.toRow(input) + val row = toRow(encoder, input) assert(row.toSeq(schema).head == decimal) } @@ -172,7 +184,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { val encoder = RowEncoder(schema).resolveAndBind() intercept[Exception] { - encoder.toRow(row) + toRow(encoder, row) } match { case e: ArithmeticException => assert(e.getMessage.contains("cannot be represented as Decimal")) @@ -184,7 +196,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { val encoder = RowEncoder(schema).resolveAndBind() - assert(encoder.fromRow(encoder.toRow(row)).get(0) == null) + assert(roundTrip(encoder, row).get(0) == null) } } @@ -237,8 +249,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue), Array(11.1111, 123456.7890123, Double.MaxValue) ) - val row = encoder.toRow(Row.fromSeq(input)) - val convertedBack = encoder.fromRow(row) + val convertedBack = roundTrip(encoder, Row.fromSeq(input)) input.zipWithIndex.map { case (array, index) => assert(convertedBack.getSeq(index) === array) } @@ -254,8 +265,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { Array(1, 2, null), Array(Array("abc", null), null), Array(Seq(Array(0L, null), null), null)) - val row = encoder.toRow(input) - val convertedBack = encoder.fromRow(row) + val convertedBack = roundTrip(encoder, input) assert(convertedBack.getSeq(0) == Seq(1, 2, null)) assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null)) assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null)) @@ -264,7 +274,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { test("RowEncoder should throw RuntimeException if input row object is null") { val schema = new StructType().add("int", IntegerType) val encoder = RowEncoder(schema) - val e = intercept[RuntimeException](encoder.toRow(null)) + val e = intercept[RuntimeException](toRow(encoder, null)) assert(e.getMessage.contains("Null value appeared in non-nullable field")) assert(e.getMessage.contains("top level Product or row object")) } @@ -273,14 +283,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val e1 = intercept[RuntimeException] { val schema = new StructType().add("a", IntegerType) val encoder = RowEncoder(schema) - encoder.toRow(Row(1.toShort)) + toRow(encoder, Row(1.toShort)) } assert(e1.getMessage.contains("java.lang.Short is not a valid external type")) val e2 = intercept[RuntimeException] { val schema = new StructType().add("a", StringType) val encoder = RowEncoder(schema) - encoder.toRow(Row(1)) + toRow(encoder, Row(1)) } assert(e2.getMessage.contains("java.lang.Integer is not a valid external type")) @@ -288,14 +298,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val schema = new StructType().add("a", new StructType().add("b", IntegerType).add("c", StringType)) val encoder = RowEncoder(schema) - encoder.toRow(Row(1 -> "a")) + toRow(encoder, Row(1 
-> "a")) } assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type")) val e4 = intercept[RuntimeException] { val schema = new StructType().add("a", ArrayType(TimestampType)) val encoder = RowEncoder(schema) - encoder.toRow(Row(Array("a"))) + toRow(encoder, Row(Array("a"))) } assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } @@ -313,9 +323,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val schema = new StructType().add("t", TimestampType) val encoder = RowEncoder(schema).resolveAndBind() val instant = java.time.Instant.parse("2019-02-26T16:56:00Z") - val row = encoder.toRow(Row(instant)) + val row = toRow(encoder, Row(instant)) assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant)) - val readback = encoder.fromRow(row) + val readback = fromRow(encoder, row) assert(readback.get(0) === instant) } } @@ -325,9 +335,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val schema = new StructType().add("d", DateType) val encoder = RowEncoder(schema).resolveAndBind() val localDate = java.time.LocalDate.parse("2019-02-27") - val row = encoder.toRow(Row(localDate)) + val row = toRow(encoder, Row(localDate)) assert(row.getLong(0) === DateTimeUtils.localDateToDays(localDate)) - val readback = encoder.fromRow(row) + val readback = fromRow(encoder, row) assert(readback.get(0).equals(localDate)) } } @@ -374,8 +384,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { try { for (_ <- 1 to 5) { input = inputGenerator.apply().asInstanceOf[Row] - val row = encoder.toRow(input) - val convertedBack = encoder.fromRow(row) + val convertedBack = roundTrip(encoder, input) assert(input == convertedBack) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 68da1faaa8f45..af6e5a3f35ee1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -699,11 +699,11 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def testHash(inputSchema: StructType): Unit = { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get - val encoder = RowEncoder(inputSchema) + val toRow = RowEncoder(inputSchema).createSerializer() val seed = scala.util.Random.nextInt() test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") { for (_ <- 1 to 10) { - val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val input = toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { case (value, dt) => Literal.create(value, dt) } @@ -717,7 +717,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val longSeed = Math.abs(seed).toLong + Integer.MAX_VALUE.toLong test(s"SPARK-30633: xxHash64 with long seed: ${inputSchema.simpleString}") { for (_ <- 1 to 10) { - val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val input = toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { case (value, dt) => Literal.create(value, dt) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ef7764dba1e9e..c40149368b055 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -445,8 +445,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testTypes.foreach { dt => genSchema(dt).map { schema => val row = RandomDataGenerator.randomRow(random, schema) - val rowConverter = RowEncoder(schema) - val internalRow = rowConverter.toRow(row) + val toRow = RowEncoder(schema).createSerializer() + val internalRow = toRow(row) val lambda = LambdaVariable("dummy", schema(0).dataType, schema(0).nullable, id = 0) checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index fb1ea7b867a6d..dd67a61015e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -60,8 +60,8 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { test("rows with all empty int arrays") { val schema = StructType(Seq( StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) - val emptyIntArray = - ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val toRow = ExpressionEncoder[Array[Int]]().resolveAndBind().createSerializer() + val emptyIntArray = toRow(Array.emptyIntArray).getArray(0) val row: UnsafeRow = UnsafeProjection.create(schema).apply( InternalRow(emptyIntArray, emptyIntArray)) testConcat(schema, row, schema, row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index da71e3a4d53e2..1e430351b5137 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -73,8 +73,8 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { arrayTypes.foreach { dt => val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) val row = RandomDataGenerator.randomRow(random, schema) - val rowConverter = RowEncoder(schema) - val internalRow = rowConverter.toRow(row) + val toRow = RowEncoder(schema).createSerializer() + val internalRow = toRow(row) val unsafeRowConverter = UnsafeProjection.create(schema) val safeRowConverter = SafeProjection.create(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index e7b1c0810a033..6d8ef68473778 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -19,6 +19,8 @@ package 
org.apache.spark.sql.catalyst.util import java.time.{ZoneId, ZoneOffset} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -70,75 +72,55 @@ class UnsafeArraySuite extends SparkFunSuite { arrayData } + private def toUnsafeArray[T : TypeTag](array: Array[T]): ArrayData = { + val converted = ExpressionEncoder[Array[T]].createSerializer().apply(array).getArray(0) + assert(converted.isInstanceOf[UnsafeArrayData]) + assert(converted.numElements == array.length) + converted + } + test("read array") { - val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind(). - toRow(booleanArray).getArray(0) - assert(unsafeBoolean.isInstanceOf[UnsafeArrayData]) - assert(unsafeBoolean.numElements == booleanArray.length) + val unsafeBoolean = toUnsafeArray(booleanArray) booleanArray.zipWithIndex.map { case (e, i) => assert(unsafeBoolean.getBoolean(i) == e) } - val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind(). - toRow(shortArray).getArray(0) - assert(unsafeShort.isInstanceOf[UnsafeArrayData]) - assert(unsafeShort.numElements == shortArray.length) + val unsafeShort = toUnsafeArray(shortArray) shortArray.zipWithIndex.map { case (e, i) => assert(unsafeShort.getShort(i) == e) } - val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind(). - toRow(intArray).getArray(0) - assert(unsafeInt.isInstanceOf[UnsafeArrayData]) - assert(unsafeInt.numElements == intArray.length) + val unsafeInt = toUnsafeArray(intArray) intArray.zipWithIndex.map { case (e, i) => assert(unsafeInt.getInt(i) == e) } - val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind(). - toRow(longArray).getArray(0) - assert(unsafeLong.isInstanceOf[UnsafeArrayData]) - assert(unsafeLong.numElements == longArray.length) + val unsafeLong = toUnsafeArray(longArray) longArray.zipWithIndex.map { case (e, i) => assert(unsafeLong.getLong(i) == e) } - val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind(). - toRow(floatArray).getArray(0) - assert(unsafeFloat.isInstanceOf[UnsafeArrayData]) - assert(unsafeFloat.numElements == floatArray.length) + val unsafeFloat = toUnsafeArray(floatArray) floatArray.zipWithIndex.map { case (e, i) => assert(unsafeFloat.getFloat(i) == e) } - val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind(). - toRow(doubleArray).getArray(0) - assert(unsafeDouble.isInstanceOf[UnsafeArrayData]) - assert(unsafeDouble.numElements == doubleArray.length) + val unsafeDouble = toUnsafeArray(doubleArray) doubleArray.zipWithIndex.map { case (e, i) => assert(unsafeDouble.getDouble(i) == e) } - val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind(). - toRow(stringArray).getArray(0) - assert(unsafeString.isInstanceOf[UnsafeArrayData]) - assert(unsafeString.numElements == stringArray.length) + val unsafeString = toUnsafeArray(stringArray) stringArray.zipWithIndex.map { case (e, i) => assert(unsafeString.getUTF8String(i).toString().equals(e)) } - val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind(). - toRow(dateArray).getArray(0) - assert(unsafeDate.isInstanceOf[UnsafeArrayData]) - assert(unsafeDate.numElements == dateArray.length) + val unsafeDate = toUnsafeArray(dateArray) dateArray.zipWithIndex.map { case (e, i) => assert(unsafeDate.get(i, DateType).asInstanceOf[Int] == e) } - val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind(). 
- toRow(timestampArray).getArray(0) - assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData]) - assert(unsafeTimestamp.numElements == timestampArray.length) + val unsafeTimestamp = toUnsafeArray(timestampArray) timestampArray.zipWithIndex.map { case (e, i) => assert(unsafeTimestamp.get(i, TimestampType).asInstanceOf[Long] == e) } @@ -149,7 +131,7 @@ class UnsafeArraySuite extends SparkFunSuite { "array", ArrayType(DecimalType(decimal.precision, decimal.scale))) val encoder = RowEncoder(schema).resolveAndBind() val externalRow = Row(decimalArray) - val ir = encoder.toRow(externalRow) + val ir = encoder.createSerializer().apply(externalRow) val unsafeDecimal = ir.getArray(0) assert(unsafeDecimal.isInstanceOf[UnsafeArrayData]) @@ -162,7 +144,7 @@ class UnsafeArraySuite extends SparkFunSuite { val schema = new StructType().add("array", ArrayType(CalendarIntervalType)) val encoder = RowEncoder(schema).resolveAndBind() val externalRow = Row(calenderintervalArray) - val ir = encoder.toRow(externalRow) + val ir = encoder.createSerializer().apply(externalRow) val unsafeCalendar = ir.getArray(0) assert(unsafeCalendar.isInstanceOf[UnsafeArrayData]) assert(unsafeCalendar.numElements == calenderintervalArray.length) @@ -170,10 +152,7 @@ class UnsafeArraySuite extends SparkFunSuite { assert(unsafeCalendar.getInterval(i) == e) } - val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind(). - toRow(intMultiDimArray).getArray(0) - assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData]) - assert(unsafeMultiDimInt.numElements == intMultiDimArray.length) + val unsafeMultiDimInt = toUnsafeArray(intMultiDimArray) intMultiDimArray.zipWithIndex.map { case (a, j) => val u = unsafeMultiDimInt.getArray(j) assert(u.isInstanceOf[UnsafeArrayData]) @@ -183,10 +162,7 @@ class UnsafeArraySuite extends SparkFunSuite { } } - val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind(). - toRow(doubleMultiDimArray).getArray(0) - assert(unsafeDouble.isInstanceOf[UnsafeArrayData]) - assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length) + val unsafeMultiDimDouble = toUnsafeArray(doubleMultiDimArray) doubleMultiDimArray.zipWithIndex.map { case (a, j) => val u = unsafeMultiDimDouble.getArray(j) assert(u.isInstanceOf[UnsafeArrayData]) @@ -216,11 +192,9 @@ class UnsafeArraySuite extends SparkFunSuite { } test("to primitive array") { - val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() - assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray)) + assert(toUnsafeArray(intArray).toIntArray().sameElements(intArray)) - val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() - assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray)) + assert(toUnsafeArray(doubleArray).toDoubleArray().sameElements(doubleArray)) } test("unsafe java serialization") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c897170c91faa..12160c9f4c192 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2965,9 +2965,8 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - // `ExpressionEncoder` is not thread-safe, here we create a new encoder. 
- val enc = resolvedEnc.copy() - plan.executeToIterator().map(enc.fromRow).asJava + val fromRow = resolvedEnc.createDeserializer() + plan.executeToIterator().map(fromRow).asJava } } @@ -3387,9 +3386,10 @@ class Dataset[T] private[sql]( new JSONOptions(Map.empty[String, String], sessionLocalTimeZone)) new Iterator[String] { + private val toRow = exprEnc.createSerializer() override def hasNext: Boolean = iter.hasNext override def next(): String = { - gen.write(exprEnc.toRow(iter.next())) + gen.write(toRow(iter.next())) gen.flush() val json = writer.toString @@ -3649,9 +3649,8 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - // `ExpressionEncoder` is not thread-safe, here we create a new encoder. - val enc = resolvedEnc.copy() - plan.executeCollect().map(enc.fromRow) + val fromRow = resolvedEnc.createDeserializer() + plan.executeCollect().map(fromRow) } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index bca841c48cacd..731aae882a7b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -345,7 +345,8 @@ class SparkSession private( // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val encoder = RowEncoder(schema) - val catalystRows = rowRDD.map(encoder.toRow) + val toRow = encoder.createSerializer() + val catalystRows = rowRDD.map(toRow) internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } @@ -459,10 +460,10 @@ class SparkSession private( * @since 2.0.0 */ def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { - // `ExpressionEncoder` is not thread-safe, here we create a new encoder. 
- val enc = encoderFor[T].copy() + val enc = encoderFor[T] + val toRow = enc.createSerializer() val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d).copy()) + val encoded = data.map(d => toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) Dataset[T](self, plan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bd2684d92a1d2..12a1a1e7fc16e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -656,7 +656,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case MemoryPlan(sink, output) => val encoder = RowEncoder(StructType.fromAttributes(output)) - LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + val toRow = encoder.createSerializer() + LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy())) :: Nil case logical.Distinct(child) => throw new IllegalStateException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index dfae5c07e0373..544b90a736071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -469,14 +469,17 @@ case class ScalaAggregator[IN, BUF, OUT]( with ImplicitCastInputTypes with Logging { - private[this] lazy val inputEncoder = inputEncoderNR.resolveAndBind() + private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer() private[this] lazy val bufferEncoder = agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind() + private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() + private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer() private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]] + private[this] lazy val outputSerializer = outputEncoder.createSerializer() def dataType: DataType = outputEncoder.objSerializer.dataType - def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType) + def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType) override lazy val deterministic: Boolean = isDeterministic @@ -491,23 +494,23 @@ case class ScalaAggregator[IN, BUF, OUT]( def createAggregationBuffer(): BUF = agg.zero def update(buffer: BUF, input: InternalRow): BUF = - agg.reduce(buffer, inputEncoder.fromRow(inputProjection(input))) + agg.reduce(buffer, inputDeserializer(inputProjection(input))) def merge(buffer: BUF, input: BUF): BUF = agg.merge(buffer, input) def eval(buffer: BUF): Any = { - val row = outputEncoder.toRow(agg.finish(buffer)) + val row = outputSerializer(agg.finish(buffer)) if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType) } private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length) def serialize(agg: BUF): Array[Byte] = - bufferEncoder.toRow(agg).asInstanceOf[UnsafeRow].getBytes() + bufferSerializer(agg).asInstanceOf[UnsafeRow].getBytes() def deserialize(storageFormat: Array[Byte]): BUF = { bufferRow.pointTo(storageFormat, storageFormat.length) - bufferEncoder.fromRow(bufferRow) + bufferDeserializer(bufferRow) } override def toString: String = s"""${nodeName}(${children.mkString(",")})""" diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index faf37609ad814..a58038d127818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -637,9 +637,9 @@ object DataSourceStrategy { output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.needConversion) { - val converters = RowEncoder(StructType.fromAttributes(output)) + val toRow = RowEncoder(StructType.fromAttributes(output)).createSerializer() rdd.mapPartitions { iterator => - iterator.map(converters.toRow) + iterator.map(toRow) } } else { rdd.asInstanceOf[RDD[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7a73ad50284c0..db4715ef068b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -332,9 +332,9 @@ object JdbcUtils extends Logging { def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = { val inputMetrics = Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics) - val encoder = RowEncoder(schema).resolveAndBind() + val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer() val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics) - internalRows.map(encoder.fromRow) + internalRows.map(fromRow) } private[spark] def resultSetToSparkInternalRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala index 64b98fb83b8fa..b4a14c6face31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala @@ -34,7 +34,9 @@ case class DescribeNamespaceExec( catalog: SupportsNamespaces, namespace: Seq[String], isExtended: Boolean) extends V2CommandExec { - private val encoder = RowEncoder(StructType.fromAttributes(output)).resolveAndBind() + private val toRow = { + RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer() + } override protected def run(): Seq[InternalRow] = { val rows = new ArrayBuffer[InternalRow]() @@ -57,6 +59,6 @@ case class DescribeNamespaceExec( } private def toCatalystRow(strs: String*): InternalRow = { - encoder.toRow(new GenericRowWithSchema(strs.toArray, schema)).copy() + toRow(new GenericRowWithSchema(strs.toArray, schema)).copy() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala index 9c280206c548e..bc6bb175f979e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala @@ -31,7 +31,9 @@ case class DescribeTableExec( table: Table, isExtended: Boolean) extends V2CommandExec { - private val 
encoder = RowEncoder(StructType.fromAttributes(output)).resolveAndBind()
+  private val toRow = {
+    RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
+  }
 
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
@@ -85,6 +87,6 @@ case class DescribeTableExec(
   private def emptyRow(): InternalRow = toCatalystRow("", "", "")
 
   private def toCatalystRow(strs: String*): InternalRow = {
-    encoder.toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
+    toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
index 42b80a15080a6..5f7b6f4061467 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
@@ -31,10 +31,11 @@ case class ShowCurrentNamespaceExec(
     catalogManager: CatalogManager) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
-    val encoder = RowEncoder(schema).resolveAndBind()
-    Seq(encoder
-      .toRow(new GenericRowWithSchema(
-        Array(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted), schema))
-      .copy())
+    val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
+    val result = new GenericRowWithSchema(Array[Any](
+      catalogManager.currentCatalog.name,
+      catalogManager.currentNamespace.quoted),
+      schema)
+    Seq(toRow(result).copy())
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
index 6f968481cb7cc..9188f4eb60d56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
@@ -44,13 +44,11 @@ case class ShowNamespacesExec(
     }
 
     val rows = new ArrayBuffer[InternalRow]()
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
 
     namespaces.map(_.quoted).map { ns =>
       if (pattern.map(StringUtils.filterPattern(Seq(ns), _).nonEmpty).getOrElse(true)) {
-        rows += encoder
-          .toRow(new GenericRowWithSchema(Array(ns), schema))
-          .copy()
+        rows += toRow(new GenericRowWithSchema(Array(ns), schema)).copy()
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
index 7905c35f55de0..0bcd7ea541045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
@@ -32,17 +32,17 @@ case class ShowTablePropertiesExec(
 
   override protected def run(): Seq[InternalRow] = {
     import scala.collection.JavaConverters._
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
     val properties = catalogTable.properties.asScala
 
     propertyKey match {
      case Some(p) =>
        val propValue = properties
          .getOrElse(p, s"Table ${catalogTable.name} does not have property: $p")
-        Seq(encoder.toRow(new GenericRowWithSchema(Array(p, propValue), schema)).copy())
+        Seq(toRow(new GenericRowWithSchema(Array(p, propValue), schema)).copy())
      case None =>
        properties.keys.map(k =>
-          encoder.toRow(new GenericRowWithSchema(Array(k, properties(k)), schema)).copy()).toSeq
+          toRow(new GenericRowWithSchema(Array(k, properties(k)), schema)).copy()).toSeq
    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
index c740e0d370dfd..820f5ae8f1b12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
@@ -37,17 +37,15 @@ case class ShowTablesExec(
     pattern: Option[String]) extends V2CommandExec with LeafExecNode {
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
 
     val tables = catalog.listTables(namespace.toArray)
     tables.map { table =>
       if (pattern.map(StringUtils.filterPattern(Seq(table.name()), _).nonEmpty).getOrElse(true)) {
-        rows += encoder
-          .toRow(
-            new GenericRowWithSchema(
-              Array(table.namespace().quoted, table.name()),
-              schema))
-          .copy()
+        val result = new GenericRowWithSchema(
+          Array(table.namespace().quoted, table.name()),
+          schema)
+        rows += toRow(result).copy()
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index fc47c5ed3ac00..368dfae0cc95e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -173,6 +173,7 @@ class TextSocketContinuousStream(
       setDaemon(true)
 
       override def run(): Unit = {
+        val toRow = encoder.createSerializer()
        try {
          while (true) {
            val line = reader.readLine()
@@ -187,7 +188,7 @@ class TextSocketContinuousStream(
              Timestamp.valueOf(
                TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
            )
-            buckets(currentOffset % numPartitions) += encoder.toRow(newData)
+            buckets(currentOffset % numPartitions) += toRow(newData)
              .copy().asInstanceOf[UnsafeRow]
          }
        }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index ea39c549bd072..e5b9e68d71026 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.ListBuffer
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -57,6 +57,8 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa
   val encoder = encoderFor[A]
   protected val attributes = encoder.schema.toAttributes
 
+  protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer()
+
   def toDS(): Dataset[A] = {
     Dataset[A](sqlContext.sparkSession, logicalPlan)
   }
@@ -176,7 +178,7 @@ case class MemoryStream[A : Encoder](
 
   def addData(data: TraversableOnce[A]): Offset = {
     val objects = data.toSeq
-    val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+    val rows = objects.iterator.map(d => toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
     logDebug(s"Adding: $objects")
     this.synchronized {
       currentOffset = currentOffset + 1
@@ -243,7 +245,7 @@ case class MemoryStream[A : Encoder](
       rows: Seq[UnsafeRow],
       startOrdinal: Int,
       endOrdinal: Int): String = {
-    val fromRow = encoder.resolveAndBind().fromRow _
+    val fromRow = encoder.resolveAndBind().createDeserializer()
     s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
       s"${rows.map(row => fromRow(row)).mkString(", ")}"
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index f94469385b281..d0cf602c7cca2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -60,7 +60,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
     // Distribute data evenly among partition lists.
     data.toSeq.zipWithIndex.map {
       case (item, index) =>
-        records(index % numPartitions) += encoder.toRow(item).copy().asInstanceOf[UnsafeRow]
+        records(index % numPartitions) += toRow(item).copy().asInstanceOf[UnsafeRow]
     }
 
     // The new target offset is the offset where all records in all partitions have been processed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 03c567c58d46a..6d5e7fd5c5cf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -30,7 +30,8 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr
     val resolvedEncoder = encoder.resolveAndBind(
       data.logicalPlan.output,
       data.sparkSession.sessionState.analyzer)
-    val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
+    val fromRow = resolvedEncoder.createDeserializer()
+    val rdd = data.queryExecution.toRdd.map[T](fromRow)(encoder.clsTag)
     val ds = data.sparkSession.createDataset(rdd)(encoder)
     batchWriter(ds, batchId)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 6e4f40ad080d4..ba54c85d07303 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -73,7 +73,7 @@ case class ForeachWriterTable[T](
           val boundEnc = enc.resolveAndBind(
             inputSchema.toAttributes,
             SparkSession.getActiveSession.get.sessionState.analyzer)
-          boundEnc.fromRow
+          boundEnc.createDeserializer()
         case Right(func) =>
           func
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index 2b674070a70ad..deab42bea36ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -172,10 +172,10 @@ class MemoryDataWriter(partition: Int, schema: StructType)
 
   private val data = mutable.Buffer[Row]()
 
-  private val encoder = RowEncoder(schema).resolveAndBind()
+  private val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
 
   override def write(row: InternalRow): Unit = {
-    data.append(encoder.fromRow(row))
+    data.append(fromRow(row))
   }
 
   override def commit(): MemoryWriterCommitMessage = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index d3ef03e9b3b74..7ca9fbb40d9f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -517,7 +517,8 @@ private[sql] object CatalogImpl {
       data: Seq[T],
       sparkSession: SparkSession): Dataset[T] = {
     val enc = ExpressionEncoder[T]()
-    val encoded = data.map(d => enc.toRow(d).copy())
+    val toRow = enc.createSerializer()
+    val encoded = data.map(d => toRow(d).copy())
     val plan = new LocalRelation(enc.schema.toAttributes, encoded)
     val queryExecution = sparkSession.sessionState.executePlan(plan)
     new Dataset[T](queryExecution, enc)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 80340b5552c6d..4b2a2b439c89e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -28,14 +28,16 @@ class GroupedIteratorSuite extends SparkFunSuite {
   test("basic") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
     val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
-    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+    val grouped = GroupedIterator(input.iterator.map(toRow),
       Seq('i.int.at(0)), schema.toAttributes)
 
     val result = grouped.map {
       case (key, data) =>
         assert(key.numFields == 1)
-        key.getInt(0) -> data.map(encoder.fromRow).toSeq
+        key.getInt(0) -> data.map(fromRow).toSeq
     }.toSeq
 
     assert(result ==
@@ -46,6 +48,8 @@ class GroupedIteratorSuite extends SparkFunSuite {
   test("group by 2 columns") {
     val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
     val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
 
     val input = Seq(
       Row(1, 2L, "a"),
@@ -54,13 +58,13 @@ class GroupedIteratorSuite extends SparkFunSuite {
       Row(2, 1L, "d"),
       Row(3, 2L, "e"))
 
-    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+    val grouped = GroupedIterator(input.iterator.map(toRow),
       Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
 
     val result = grouped.map {
       case (key, data) =>
         assert(key.numFields == 2)
-        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
+        (key.getInt(0), key.getLong(1), data.map(fromRow).toSeq)
     }.toSeq
 
     assert(result ==
@@ -73,8 +77,9 @@ class GroupedIteratorSuite extends SparkFunSuite {
   test("do nothing to the value iterator") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
     val encoder = RowEncoder(schema).resolveAndBind()
+    val toRow = encoder.createSerializer()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
-    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+    val grouped = GroupedIterator(input.iterator.map(toRow),
       Seq('i.int.at(0)), schema.toAttributes)
 
     assert(grouped.length == 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
index f582d844cdc47..9b0389c6d1ea4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
@@ -40,13 +40,16 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
     UnsafeArrayData.calculateHeaderPortionInBytes(count)
   }
 
+  private lazy val intEncoder = ExpressionEncoder[Array[Int]]().resolveAndBind()
+
+  private lazy val doubleEncoder = ExpressionEncoder[Array[Double]]().resolveAndBind()
+
   def readUnsafeArray(iters: Int): Unit = {
     val count = 1024 * 1024 * 16
     val rand = new Random(42)
-
+    val intArrayToRow = intEncoder.createSerializer()
     val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
-    val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
-    val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
+    val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0)
     val readIntArray = { i: Int =>
       var n = 0
       while (n < iters) {
@@ -62,8 +65,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
     }
 
     val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
-    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
-    val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
+    val doubleArrayToRow = doubleEncoder.createSerializer()
+    val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0)
     val readDoubleArray = { i: Int =>
       var n = 0
       while (n < iters) {
@@ -90,12 +93,12 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
 
     var intTotalLength: Int = 0
     val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
-    val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
+    val intArrayToRow = intEncoder.createSerializer()
     val writeIntArray = { i: Int =>
       var len = 0
       var n = 0
       while (n < iters) {
-        len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements()
+        len += intArrayToRow(intPrimitiveArray).getArray(0).numElements()
         n += 1
       }
       intTotalLength = len
@@ -103,12 +106,12 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
 
     var doubleTotalLength: Int = 0
     val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
-    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
+    val doubleArrayToRow = doubleEncoder.createSerializer()
     val writeDoubleArray = { i: Int =>
       var len = 0
      var n = 0
      while (n < iters) {
-        len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements()
+        len += doubleArrayToRow(doublePrimitiveArray).getArray(0).numElements()
        n += 1
      }
      doubleTotalLength = len
@@ -126,8 +129,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
 
     var intTotalLength: Int = 0
     val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
-    val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
-    val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
+    val intArrayToRow = intEncoder.createSerializer()
+    val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0)
     val readIntArray = { i: Int =>
       var len = 0
       var n = 0
@@ -140,8 +143,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
 
     var doubleTotalLength: Int = 0
     val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
-    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
-    val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
+    val doubleArrayToRow = doubleEncoder.createSerializer()
+    val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0)
     val readDoubleArray = { i: Int =>
       var len = 0
       var n = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 2cd142f913072..8462916daaab8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -304,7 +304,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
     val partitionedFile = mock(classOf[PartitionedFile])
     when(partitionedFile.filePath).thenReturn(file.getPath)
     val encoder = RowEncoder(requiredSchema).resolveAndBind()
-    encoder.fromRow(reader(partitionedFile).next())
+    encoder.createDeserializer().apply(reader(partitionedFile).next())
   }
 
   test("column pruning") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 6d5ad873eedea..8d5439534b513 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -144,16 +144,22 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
     }
   }
 
+  private def createToExternalRowConverter[A : Encoder](): A => Row = {
+    val encoder = encoderFor[A]
+    val toInternalRow = encoder.createSerializer()
+    val toExternalRow = RowEncoder(encoder.schema).resolveAndBind().createDeserializer()
+    toExternalRow.compose(toInternalRow)
+  }
+
   /**
    * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
    * This operation automatically blocks until all added data has been processed.
    */
   object CheckAnswer {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
-      val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
+      val toExternalRow = createToExternalRowConverter[A]()
       CheckAnswerRows(
-        data.map(d => toExternalRow.fromRow(encoder.toRow(d))),
+        data.map(toExternalRow),
         lastOnly = false,
         isSorted = false)
     }
@@ -174,10 +180,9 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
     }
 
     def apply[A: Encoder](isSorted: Boolean, data: A*): CheckAnswerRows = {
-      val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
+      val toExternalRow = createToExternalRowConverter[A]()
       CheckAnswerRows(
-        data.map(d => toExternalRow.fromRow(encoder.toRow(d))),
+        data.map(toExternalRow),
         lastOnly = true,
         isSorted = isSorted)
     }
@@ -215,9 +220,8 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
     def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty)
 
     def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = {
-      val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
-      CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d))))
+      val toExternalRow = createToExternalRowConverter[A]()
+      CheckNewAnswerRows((data +: moreData).map(toExternalRow))
     }
 
     def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows)
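
All of the call sites above follow the same mechanical rewrite: instead of calling toRow/fromRow on the encoder itself, resolve and bind the encoder once, materialize a serializer and/or deserializer from it, and apply those as plain functions. A minimal sketch of the new call pattern follows; the Person case class, object name, and println are illustrative placeholders only and are not part of this patch.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Illustrative type only; any class with an ExpressionEncoder works.
    case class Person(name: String, age: Int)

    object EncoderMigrationSketch {
      def main(args: Array[String]): Unit = {
        // Resolve and bind the encoder once, then create the conversion
        // functions from it. The encoder stays immutable; the serializer and
        // deserializer are the (cheap to create, non-thread-safe) pieces.
        val encoder = ExpressionEncoder[Person]().resolveAndBind()
        val toRow = encoder.createSerializer()      // Person => InternalRow
        val fromRow = encoder.createDeserializer()  // InternalRow => Person

        // The serializer reuses its result row, so copy() before keeping it,
        // just as the streaming and command-exec call sites above do.
        val row: InternalRow = toRow(Person("a", 1)).copy()
        val person: Person = fromRow(row)
        println(person)
      }
    }

Keeping the stateful buffer reuse confined to the serializer/deserializer instances is what lets the Encoder contract change to "should be thread-safe" while callers pay the per-use cost only where they actually convert rows.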