@@ -166,17 +166,69 @@ object ExpressionEncoder {
166166 e5 : ExpressionEncoder [T5 ]): ExpressionEncoder [(T1 , T2 , T3 , T4 , T5 )] =
167167 tuple(Seq (e1, e2, e3, e4, e5)).asInstanceOf [ExpressionEncoder [(T1 , T2 , T3 , T4 , T5 )]]
168168
169+ private val anyObjectType = ObjectType (classOf [Any ])
170+
169171 /**
170- * Function that deserializes an [[InternalRow ]] into an object of type `T`. Instances of this
171- * class are not meant to be thread-safe.
172+ * Function that deserializes an [[InternalRow ]] into an object of type `T`. This class is not
173+ * thread-safe.
172174 */
173- abstract class Deserializer [T ] extends (InternalRow => T ) with Serializable
175+ class Deserializer [T ](private val expressions : Seq [Expression ])
176+ extends (InternalRow => T ) with Serializable {
177+ @ transient
178+ private [this ] var constructProjection : Projection = _
179+
180+ private def initialize (): Unit = {
181+ constructProjection = SafeProjection .create(expressions)
182+ }
183+ initialize()
184+
185+ override def apply (row : InternalRow ): T = try {
186+ constructProjection(row).get(0 , anyObjectType).asInstanceOf [T ]
187+ } catch {
188+ case e : Exception =>
189+ throw new RuntimeException (s " Error while decoding: $e\n " +
190+ s " ${expressions.map(_.simpleString(SQLConf .get.maxToStringFields)).mkString(" \n " )}" , e)
191+ }
192+
193+ private def readObject (in : ObjectInputStream ): Unit = {
194+ in.defaultReadObject()
195+ initialize()
196+ }
197+ }
174198
175199 /**
176- * Function that serializesa an object of type `T` to an [[InternalRow ]]. Instances of this
177- * class are not meant to be thread-safe.
200+ * Function that serializesa an object of type `T` to an [[InternalRow ]]. This class is not
201+ * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow ]]
202+ * object. Thus, the caller should copy the result before making another call if required.
178203 */
179- abstract class Serializer [T ] extends (T => InternalRow ) with Serializable
204+ class Serializer [T ](private val expressions : Seq [Expression ])
205+ extends (T => InternalRow ) with Serializable {
206+ @ transient
207+ private [this ] var inputRow : GenericInternalRow = _
208+
209+ @ transient
210+ private [this ] var extractProjection : UnsafeProjection = _
211+
212+ private def initialize (): Unit = {
213+ inputRow = new GenericInternalRow (1 )
214+ extractProjection = GenerateUnsafeProjection .generate(expressions)
215+ }
216+ initialize()
217+
218+ override def apply (t : T ): InternalRow = try {
219+ inputRow(0 ) = t
220+ extractProjection(inputRow)
221+ } catch {
222+ case e : Exception =>
223+ throw new RuntimeException (s " Error while encoding: $e\n " +
224+ s " ${expressions.map(_.simpleString(SQLConf .get.maxToStringFields)).mkString(" \n " )}" , e)
225+ }
226+
227+ private def readObject (in : ObjectInputStream ): Unit = {
228+ in.defaultReadObject()
229+ initialize()
230+ }
231+ }
180232}
181233
182234/**
@@ -350,62 +402,15 @@ case class ExpressionEncoder[T](
350402 * `serializer.apply(..)` are allowed to return the same actual [[InternalRow ]] object. Thus,
351403 * the caller should copy the result before making another call if required.
352404 */
353- def createSerializer (): Serializer [T ] = new Serializer [T ] {
354- @ transient
355- private var inputRow : GenericInternalRow = _
356-
357- @ transient
358- private var extractProjection : UnsafeProjection = _
359-
360- private def initialize (): Unit = {
361- inputRow = new GenericInternalRow (1 )
362- extractProjection = GenerateUnsafeProjection .generate(optimizedSerializer)
363- }
364- initialize()
365-
366- override def apply (t : T ): InternalRow = try {
367- inputRow(0 ) = t
368- extractProjection(inputRow)
369- } catch {
370- case e : Exception =>
371- throw new RuntimeException (s " Error while encoding: $e\n " +
372- s " ${serializer.map(_.simpleString(SQLConf .get.maxToStringFields)).mkString(" \n " )}" , e)
373- }
374-
375- private def readObject (in : ObjectInputStream ): Unit = {
376- in.defaultReadObject()
377- initialize()
378- }
379- }
405+ def createSerializer (): Serializer [T ] = new Serializer [T ](optimizedSerializer)
380406
381407 /**
382408 * Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
383409 *
384410 * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
385411 * deserializer.
386412 */
387- def createDeserializer (): Deserializer [T ] = new Deserializer [T ] {
388- @ transient
389- private var constructProjection : Projection = _
390-
391- private def initialize (): Unit = {
392- constructProjection = SafeProjection .create(optimizedDeserializer)
393- }
394- initialize()
395-
396- override def apply (row : InternalRow ): T = try {
397- constructProjection(row).get(0 , ObjectType (clsTag.runtimeClass)).asInstanceOf [T ]
398- } catch {
399- case e : Exception =>
400- throw new RuntimeException (s " Error while decoding: $e\n " +
401- s " ${deserializer.simpleString(SQLConf .get.maxToStringFields)}" , e)
402- }
403-
404- private def readObject (in : ObjectInputStream ): Unit = {
405- in.defaultReadObject()
406- initialize()
407- }
408- }
413+ def createDeserializer (): Deserializer [T ] = new Deserializer [T ](optimizedDeserializer)
409414
410415 /**
411416 * The process of resolution to a given schema throws away information about where a given field
0 commit comments