@@ -90,8 +90,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
       if (requiredSchema.isEmpty) {
         filteredResult.map(_ => emptyUnsafeRow)
       } else {
-        val converter = RowEncoder(requiredSchema)
-        filteredResult.map(row => converter.toRow(row))
+        val toRow = RowEncoder(requiredSchema).createSerializer()
+        filteredResult.map(row => toRow(row))
       }
     }
   }
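The call-site change above is the same everywhere in this PR: instead of calling `toRow` on the encoder itself, callers first obtain a dedicated serializer function. A minimal before/after sketch, assuming the post-PR `RowEncoder` API (the schema and values are illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object MigrationSketch {
  // Illustrative schema; any StructType works the same way.
  val schema: StructType = new StructType()
    .add("id", IntegerType)
    .add("name", StringType)

  def main(args: Array[String]): Unit = {
    // Before this PR: the encoder carried the mutable projection state, so
    // `converter.toRow(row)` could not be shared across threads.
    //   val converter = RowEncoder(schema)
    //   val internal = converter.toRow(Row(1, "a"))

    // After this PR: the encoder is stateless; the stateful (and still
    // non-thread-safe) part is the Serializer created at the use site.
    val toRow: Row => InternalRow = RowEncoder(schema).createSerializer()
    val internal: InternalRow = toRow(Row(1, "a"))
    println(internal.numFields) // 2
  }
}
```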
@@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat
       LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
     }
 
-    val converter = RowEncoder(dataSchema)
+    val toRow = RowEncoder(dataSchema).createSerializer()
     val fullOutput = dataSchema.map { f =>
       AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
     }
@@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat
 
     points.map { pt =>
       val features = if (isSparse) pt.features.toSparse else pt.features.toDense
-      requiredColumns(converter.toRow(Row(pt.label, features)))
+      requiredColumns(toRow(Row(pt.label, features)))
     }
   }
 }
@@ -38,20 +38,22 @@ object UDTSerializationBenchmark extends BenchmarkBase {
     val iters = 1e2.toInt
     val numRows = 1e3.toInt
 
-    val encoder = ExpressionEncoder[Vector].resolveAndBind()
+    val encoder = ExpressionEncoder[Vector]().resolveAndBind()
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
 
     val vectors = (1 to numRows).map { i =>
       Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
     }.toArray
-    val rows = vectors.map(encoder.toRow)
+    val rows = vectors.map(toRow)
 
     val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output)
 
     benchmark.addCase("serialize") { _ =>
       var sum = 0
       var i = 0
       while (i < numRows) {
-        sum += encoder.toRow(vectors(i)).numFields
+        sum += toRow(vectors(i)).numFields
         i += 1
       }
     }
@@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
       var sum = 0
       var i = 0
       while (i < numRows) {
-        sum += encoder.fromRow(rows(i)).numActives
+        sum += fromRow(rows(i)).numActives
         i += 1
       }
     }
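The benchmark above also shows the intended usage pattern: the converters are created once, outside the timed loop, because each one lazily compiles a code-generated projection on first use. A round-trip sketch under the same assumptions (mllib's `Vector` UDT on the classpath):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object ReuseSketch {
  def main(args: Array[String]): Unit = {
    val encoder = ExpressionEncoder[Vector]().resolveAndBind()
    // Create once and reuse: re-creating a serializer per element would
    // re-trigger codegen for every record.
    val toRow = encoder.createSerializer()
    val fromRow = encoder.createDeserializer()

    val vectors = Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
    // The serializer reuses its output buffer, so copy rows that are kept.
    val rows = vectors.map(v => toRow(v).copy())
    assert(vectors == rows.map(fromRow))
  }
}
```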
@@ -58,8 +58,7 @@ import org.apache.spark.sql.types._
  * }}}
  *
  * == Implementation ==
- *  - Encoders are not required to be thread-safe and thus they do not need to use locks to guard
- *    against concurrent access if they reuse internal buffers to improve performance.
+ *  - Encoders should be thread-safe.
  *
  * @since 1.6.0
  */
@@ -76,10 +75,4 @@ trait Encoder[T] extends Serializable {
    * A ClassTag that can be used to construct an Array to contain a collection of `T`.
    */
   def clsTag: ClassTag[T]
-
-  /**
-   * Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed
-   * up the [[Encoder]] creation.
-   */
-  def makeCopy: Encoder[T]
 }
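The doc change above is the heart of the PR: `Encoder` instances themselves are now required to be thread-safe and shareable, while the stateful converters are created per use. A sketch of the intended usage under that contract (the thread bodies are illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object ThreadSafetySketch {
  def main(args: Array[String]): Unit = {
    // One immutable encoder, safely shared by all threads.
    val encoder = ExpressionEncoder[(Int, String)]()

    val threads = (1 to 4).map { id =>
      new Thread(() => {
        // Serializers cache a reusable input row and a compiled projection,
        // so each thread creates its own instead of sharing one.
        val toRow = encoder.createSerializer()
        val row = toRow((id, s"thread-$id"))
        println(s"thread $id serialized ${row.numFields} fields")
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```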
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import java.io.ObjectInputStream
+
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
@@ -162,6 +165,56 @@ object ExpressionEncoder {
       e4: ExpressionEncoder[T4],
       e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
     tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+
+  private val anyObjectType = ObjectType(classOf[Any])
+
+  /**
+   * Function that deserializes an [[InternalRow]] into an object of type `T`. This class is not
+   * thread-safe.
+   */
+  class Deserializer[T](private val expressions: Seq[Expression])
+    extends (InternalRow => T) with Serializable {
+    @transient
+    private[this] var constructProjection: Projection = _
+
+    override def apply(row: InternalRow): T = try {
+      if (constructProjection == null) {
+        constructProjection = SafeProjection.create(expressions)
+      }
+      constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while decoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+  }
+
+  /**
+   * Function that serializes an object of type `T` to an [[InternalRow]]. This class is not
+   * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow]]
+   * object. Thus, the caller should copy the result before making another call if required.
+   */
+  class Serializer[T](private val expressions: Seq[Expression])
+    extends (T => InternalRow) with Serializable {
+    @transient
+    private[this] var inputRow: GenericInternalRow = _
+
+    @transient
+    private[this] var extractProjection: UnsafeProjection = _
+
+    override def apply(t: T): InternalRow = try {
+      if (extractProjection == null) {
+        inputRow = new GenericInternalRow(1)
+        extractProjection = GenerateUnsafeProjection.generate(expressions)
+      }
+      inputRow(0) = t
+      extractProjection(inputRow)
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while encoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+  }
+}
 
 /**
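Because `Serializer.apply` writes into one cached `GenericInternalRow` and the generated `UnsafeProjection` reuses its result buffer, consecutive calls may hand back the very same row object. A small sketch of the aliasing this doc comment warns about (the encoder type is illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object CopySemanticsSketch {
  def main(args: Array[String]): Unit = {
    val toRow = ExpressionEncoder[(Int, Int)]().createSerializer()

    val first = toRow((1, 2))   // backed by the serializer's reused buffer
    val second = toRow((3, 4))  // may be the same object, now holding (3, 4)
    println(first eq second)    // typically true

    // Copy before the next apply(..) if the row must survive:
    val kept = toRow((5, 6)).copy()
    println(kept.getInt(0))     // 5
  }
}
```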
@@ -302,25 +355,22 @@ case class ExpressionEncoder[T](
   }
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate({
+  private lazy val optimizedDeserializer: Seq[Expression] = {
     // When using `ExpressionEncoder` directly, we will skip the normal query processing steps
     // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
     // important to codegen performance.
-    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
+    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
     optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
-  })
-
-  @transient
-  private lazy val inputRow = new GenericInternalRow(1)
+  }
 
   @transient
-  private lazy val constructProjection = SafeProjection.create({
+  private lazy val optimizedSerializer = {
     // When using `ExpressionEncoder` directly, we will skip the normal query processing steps
     // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
     // important to codegen performance.
-    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
+    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
     optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
-  })
+  }
 
   /**
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
@@ -332,31 +382,21 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
-   * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
-   * copy the result before making another call if required.
+   * Create a serializer that can convert an object of type `T` to a Spark SQL Row.
+   *
+   * Note that the returned [[Serializer]] is not thread safe. Multiple calls to
+   * `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus,
+   * the caller should copy the result before making another call if required.
    */
-  def toRow(t: T): InternalRow = try {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  } catch {
-    case e: Exception =>
-      throw new RuntimeException(s"Error while encoding: $e\n" +
-        s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
-  }
+  def createSerializer(): Serializer[T] = new Serializer[T](optimizedSerializer)
 
   /**
-   * Returns an object of type `T`, extracting the required values from the provided row. Note that
-   * you must `resolveAndBind` an encoder to a specific schema before you can call this
-   * function.
+   * Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
+   *
+   * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
+   * deserializer.
    */
-  def fromRow(row: InternalRow): T = try {
-    constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
-  } catch {
-    case e: Exception =>
-      throw new RuntimeException(s"Error while decoding: $e\n" +
-        s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
-  }
+  def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer)
 
   /**
    * The process of resolution to a given schema throws away information about where a given field
@@ -383,8 +423,6 @@ case class ExpressionEncoder[T](
       .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
 
   override def toString: String = s"class[$schemaString]"
-
-  override def makeCopy: ExpressionEncoder[T] = copy()
 }
 
 // A dummy logical plan that can hold expressions and go through optimizer rules.
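Taken together, the two factories split the old `toRow`/`fromRow` pair cleanly: serialization works on an unresolved encoder, while deserialization still requires `resolveAndBind` first. A minimal round trip, assuming a hypothetical `Person` case class:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical type used only for this example.
case class Person(name: String, age: Int)

object RoundTripSketch {
  def main(args: Array[String]): Unit = {
    val encoder = ExpressionEncoder[Person]()

    // Serializing needs no resolution.
    val toRow = encoder.createSerializer()
    val row = toRow(Person("ann", 42))

    // Deserializing does: bind to a concrete schema (here, the encoder's
    // own) before creating the function.
    val fromRow = encoder.resolveAndBind().createDeserializer()
    println(fromRow(row)) // Person(ann,42)
  }
}
```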
@@ -110,8 +110,8 @@ case class ScalaUDF(
     } else {
       val encoder = inputEncoders(i)
       if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
-        val enc = encoder.get.resolveAndBind()
-        row: Any => enc.fromRow(row.asInstanceOf[InternalRow])
+        val fromRow = encoder.get.resolveAndBind().createDeserializer()
+        row: Any => fromRow(row.asInstanceOf[InternalRow])
       } else {
         CatalystTypeConverters.createToScalaConverter(dataType)
       }
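In `ScalaUDF` the deserializer is chosen per input: struct-encoded arguments (e.g. case classes) are decoded with the bound encoder, everything else falls back to the generic Catalyst converter. A simplified stand-in for that dispatch (the helper name and wrapper object are hypothetical):

```scala
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.DataType

object UdfConverterSketch {
  // Mirrors the branch in ScalaUDF above, outside of its class context.
  def scalaConverter(
      encoder: Option[ExpressionEncoder[_]],
      dataType: DataType): Any => Any = {
    if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
      // Struct-encoded argument: decode the whole row into the user's type.
      val fromRow = encoder.get.resolveAndBind().createDeserializer()
      (v: Any) => fromRow(v.asInstanceOf[InternalRow])
    } else {
      // Primitive, array, map, etc.: generic Catalyst-to-Scala conversion.
      CatalystTypeConverters.createToScalaConverter(dataType)
    }
  }
}
```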
@@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase {
   def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = {
     runBenchmark(name) {
       val generator = RandomDataGenerator.forType(schema, nullable = false).get
-      val encoder = RowEncoder(schema)
+      val toRow = RowEncoder(schema).createSerializer()
       val attrs = schema.toAttributes
       val safeProjection = GenerateSafeProjection.generate(attrs, attrs)
 
       val rows = (1 to numRows).map(_ =>
         // The output of the encoder is an UnsafeRow; use safeProjection to turn it into the safe format.
-        safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
+        safeProjection(toRow(generator().asInstanceOf[Row])).copy()
       ).toArray
 
       val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output)
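The comment in the hunk above points at a detail worth keeping in mind: the serializers produce binary `UnsafeRow`s, and a safe projection is what turns them into plain JVM-object rows. A compact illustration under the same test-side APIs:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.types.{IntegerType, StructType}

object SafeVsUnsafeSketch {
  def main(args: Array[String]): Unit = {
    val schema = new StructType().add("i", IntegerType)
    val toRow = RowEncoder(schema).createSerializer()

    val unsafe = toRow(Row(7))
    assert(unsafe.isInstanceOf[UnsafeRow]) // serializer output is binary

    // Turn it into a "safe" row of ordinary JVM objects.
    val attrs = schema.toAttributes
    val safe = GenerateSafeProjection.generate(attrs, attrs)(unsafe).copy()
    println(safe.getInt(0)) // 7
  }
}
```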
@@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {
 
   def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
     val generator = RandomDataGenerator.forType(schema, nullable = false).get
-    val encoder = RowEncoder(schema)
-    (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
+    val toRow = RowEncoder(schema).createSerializer()
+    (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray
   }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
@@ -42,37 +43,44 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass])
 class EncoderResolutionSuite extends PlanTest {
   private val str = UTF8String.fromString("hello")
 
+  def testFromRow[T](
+      encoder: ExpressionEncoder[T],
+      attributes: Seq[Attribute],
+      row: InternalRow): Unit = {
+    encoder.resolveAndBind(attributes).createDeserializer().apply(row)
+  }
+
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
 
     // int type can be up cast to long type
     val attrs1 = Seq('a.string, 'b.int)
-    encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
+    testFromRow(encoder, attrs1, InternalRow(str, 1))
 
     // int type can be up cast to string type
     val attrs2 = Seq('a.int, 'b.long)
-    encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
+    testFromRow(encoder, attrs2, InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+    testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
     val encoder = ExpressionEncoder.tuple(
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+    testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
   test("real type doesn't match encoder schema but they are compatible: primitive array") {
     val encoder = ExpressionEncoder[PrimitiveArrayClass]
     val attrs = Seq('arr.array(IntegerType))
     val array = new GenericArrayData(Array(1, 2, 3))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+    testFromRow(encoder, attrs, InternalRow(array))
   }
 
   test("the real type is not compatible with encoder schema: primitive array") {
@@ -93,7 +101,7 @@ class EncoderResolutionSuite extends PlanTest {
     val encoder = ExpressionEncoder[ArrayClass]
     val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
     val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+    testFromRow(encoder, attrs, InternalRow(array))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested array") {
@@ -103,7 +111,7 @@ class EncoderResolutionSuite extends PlanTest {
     val attrs = Seq('nestedArr.array(et))
     val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
     val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+    testFromRow(encoder, attrs, InternalRow(outerArr))
   }
 
   test("the real type is not compatible with encoder schema: non-array field") {
@@ -142,14 +150,14 @@ class EncoderResolutionSuite extends PlanTest {
     val attrs = 'a.array(IntegerType) :: Nil
 
     // It should pass analysis
-    val bound = encoder.resolveAndBind(attrs)
+    val fromRow = encoder.resolveAndBind(attrs).createDeserializer()
 
     // If no null values appear, it should work fine
-    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+    fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
 
     // If there is null value, it should throw runtime exception
     val e = intercept[RuntimeException] {
-      bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+      fromRow(InternalRow(new GenericArrayData(Array(1, null))))
     }
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
   }