Commit 5bcff74 (1 parent: 2c402f9)

Move serializer and deserializer implementation into the companion.

2 files changed: +63, -55 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ import org.apache.spark.sql.types._
  * Encoders.bean(MyClass.class);
  * }}}
  *
+ * == Implementation ==
+ * - Encoders should be thread-safe.
+ *
  * @since 1.6.0
  */
 @implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
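The added note documents that Encoder implementations are expected to be thread-safe; the Serializer and Deserializer objects created from an ExpressionEncoder, by contrast, are documented below as not thread-safe. A minimal sketch of the resulting usage pattern, assuming a build that contains this change (the object name and the choice of Long here are illustrative only, not part of the commit):

// Sketch: share the encoder, but give each thread its own Serializer,
// since Serializer instances are not thread-safe.
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object ThreadSafetySketch {
  def main(args: Array[String]): Unit = {
    val shared = ExpressionEncoder[Long]().resolveAndBind()  // safe to share across threads

    val threads = (1 to 4).map { i =>
      new Thread(() => {
        val toRow = shared.createSerializer()   // one serializer per thread
        println(toRow(i.toLong).getLong(0))     // prints 1..4 in some order
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}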

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 60 additions & 55 deletions
@@ -166,17 +166,69 @@ object ExpressionEncoder {
       e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
     tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
 
+  private val anyObjectType = ObjectType(classOf[Any])
+
   /**
-   * Function that deserializes an [[InternalRow]] into an object of type `T`. Instances of this
-   * class are not meant to be thread-safe.
+   * Function that deserializes an [[InternalRow]] into an object of type `T`. This class is not
+   * thread-safe.
    */
-  abstract class Deserializer[T] extends (InternalRow => T) with Serializable
+  class Deserializer[T](private val expressions: Seq[Expression])
+    extends (InternalRow => T) with Serializable {
+    @transient
+    private[this] var constructProjection: Projection = _
+
+    private def initialize(): Unit = {
+      constructProjection = SafeProjection.create(expressions)
+    }
+    initialize()
+
+    override def apply(row: InternalRow): T = try {
+      constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while decoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+
+    private def readObject(in: ObjectInputStream): Unit = {
+      in.defaultReadObject()
+      initialize()
+    }
+  }
 
   /**
-   * Function that serializes an object of type `T` to an [[InternalRow]]. Instances of this
-   * class are not meant to be thread-safe.
+   * Function that serializes an object of type `T` to an [[InternalRow]]. This class is not
+   * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow]]
+   * object. Thus, the caller should copy the result before making another call if required.
    */
-  abstract class Serializer[T] extends (T => InternalRow) with Serializable
+  class Serializer[T](private val expressions: Seq[Expression])
+    extends (T => InternalRow) with Serializable {
+    @transient
+    private[this] var inputRow: GenericInternalRow = _
+
+    @transient
+    private[this] var extractProjection: UnsafeProjection = _
+
+    private def initialize(): Unit = {
+      inputRow = new GenericInternalRow(1)
+      extractProjection = GenerateUnsafeProjection.generate(expressions)
+    }
+    initialize()
+
+    override def apply(t: T): InternalRow = try {
+      inputRow(0) = t
+      extractProjection(inputRow)
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while encoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+
+    private def readObject(in: ObjectInputStream): Unit = {
+      in.defaultReadObject()
+      initialize()
+    }
+  }
 }
 
 /**
@@ -350,62 +402,15 @@ case class ExpressionEncoder[T](
    * `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus,
    * the caller should copy the result before making another call if required.
    */
-  def createSerializer(): Serializer[T] = new Serializer[T] {
-    @transient
-    private var inputRow: GenericInternalRow = _
-
-    @transient
-    private var extractProjection: UnsafeProjection = _
-
-    private def initialize(): Unit = {
-      inputRow = new GenericInternalRow(1)
-      extractProjection = GenerateUnsafeProjection.generate(optimizedSerializer)
-    }
-    initialize()
-
-    override def apply(t: T): InternalRow = try {
-      inputRow(0) = t
-      extractProjection(inputRow)
-    } catch {
-      case e: Exception =>
-        throw new RuntimeException(s"Error while encoding: $e\n" +
-          s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
-    }
-
-    private def readObject(in: ObjectInputStream): Unit = {
-      in.defaultReadObject()
-      initialize()
-    }
-  }
+  def createSerializer(): Serializer[T] = new Serializer[T](optimizedSerializer)
 
   /**
    * Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
    *
    * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
    * deserializer.
    */
-  def createDeserializer(): Deserializer[T] = new Deserializer[T] {
-    @transient
-    private var constructProjection: Projection = _
-
-    private def initialize(): Unit = {
-      constructProjection = SafeProjection.create(optimizedDeserializer)
-    }
-    initialize()
-
-    override def apply(row: InternalRow): T = try {
-      constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
-    } catch {
-      case e: Exception =>
-        throw new RuntimeException(s"Error while decoding: $e\n" +
-          s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
-    }
-
-    private def readObject(in: ObjectInputStream): Unit = {
-      in.defaultReadObject()
-      initialize()
-    }
-  }
+  def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer)
 
   /**
    * The process of resolution to a given schema throws away information about where a given field
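For reference, a minimal round trip through the relocated classes might look like the sketch below. It assumes a build containing this change; the Point case class is made up for illustration, while createSerializer and createDeserializer are the factory methods shown in the diff above. The copy() call follows the scaladoc note that Serializer.apply reuses the same InternalRow between calls:

// Sketch: serialize a value to an InternalRow and back using the
// companion-object Serializer and Deserializer.
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Int, y: Int)  // illustrative type, not part of the commit

object RoundTripSketch {
  def main(args: Array[String]): Unit = {
    // resolveAndBind is required before createDeserializer can be used.
    val enc = ExpressionEncoder[Point]().resolveAndBind()

    val toRow = enc.createSerializer()      // ExpressionEncoder.Serializer[Point]
    val fromRow = enc.createDeserializer()  // ExpressionEncoder.Deserializer[Point]

    // Copy the row: apply(..) returns the same InternalRow object on every call.
    val row = toRow(Point(1, 2)).copy()
    println(fromRow(row))                   // Point(1,2)
  }
}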
