@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {
 
-  val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,24 +86,44 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging {
 
-  val rowIn = new DataInputStream(new BufferedInputStream(in))
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
 
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }
+
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(), readValueFunc()).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc().asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc().asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -113,7 +133,8 @@ private[sql] class Serializer2DeserializationStream(
 
 private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -130,7 +151,7 @@ private[sql] class SparkSqlSerializer2Instance(
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }
 
@@ -141,12 +162,16 @@ private[sql] class SparkSqlSerializer2Instance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
  * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable {
 
-  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -316,12 +341,12 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream): () => Row = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
         var i = 0
-        val mutableRow = new GenericMutableRow(schema.length)
         while (i < schema.length) {
           schema(i) match {
             // When we read values from the underlying stream, we also first read the null byte
@@ -435,8 +460,6 @@ private[sql] object SparkSqlSerializer2 {
         }
 
         mutableRow
-      } else {
-        null
      }
    }
  }
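
For context on what the new hasKeyOrdering flag guards against, here is a minimal sketch (not part of the patch; the readAllKeys helper and its parameters are hypothetical) of why a single reused mutable row becomes unsafe once a ShuffledRDD sorts by key and therefore buffers the deserialized rows:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow

// Hypothetical reader loop that keeps every deserialized key around,
// as a sort over the shuffled data would.
def readAllKeys(readKey: MutableRow => Row, getKey: () => MutableRow, n: Int): Seq[Row] = {
  val keys = new ArrayBuffer[Row](n)
  var i = 0
  while (i < n) {
    // If getKey() always returns the same SpecificMutableRow (the reuse path),
    // every element of `keys` ends up aliasing the most recently read key.
    // When hasKeyOrdering is true, getKey() allocates a fresh GenericMutableRow
    // per record, so the buffered keys stay distinct and can be sorted safely.
    keys += readKey(getKey())
    i += 1
  }
  keys
}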