Commit 53a5eaa

Josh's comments.

1 parent 487f540 · commit 53a5eaa

2 files changed: +51 -25 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 7 additions & 4 deletions
@@ -84,6 +84,7 @@ case class Exchange(
   def serializer(
       keySchema: Array[DataType],
       valueSchema: Array[DataType],
+      hasKeyOrdering: Boolean,
       numPartitions: Int): Serializer = {
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
@@ -99,7 +100,7 @@ case class Exchange(
 
     val serializer = if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
@@ -142,7 +143,8 @@ case class Exchange(
         }
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))
 
         shuffled.map(_._2)
 
@@ -167,7 +169,8 @@ case class Exchange(
           new ShuffledRDD[Row, Null, Null](rdd, part)
         }
         val keySchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))
 
         shuffled.map(_._1)
 
@@ -187,7 +190,7 @@ case class Exchange(
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(null, valueSchema, 1))
+        shuffled.setSerializer(serializer(null, valueSchema, false, 1))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
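
Read together, these hunks thread a new hasKeyOrdering flag from each Exchange call site (derived as newOrdering.nonEmpty, or false for the single-partition case) into SparkSqlSerializer2. The sketch below is only a condensed reading of the diff, not the committed code: chooseSerializer, canUseSqlSerializer2, and conf are hypothetical stand-ins for members the hunks do not show, and because both serializer classes are private[sql] it would only compile inside the org.apache.spark.sql.execution package.

package org.apache.spark.sql.execution

import org.apache.spark.SparkConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.types.DataType

// Hypothetical condensation of Exchange.serializer after this commit; the
// "canUse..." and "conf" parameters stand in for members of the surrounding
// operator that the diff does not show.
object SerializerChoiceSketch {
  def chooseSerializer(
      keySchema: Array[DataType],
      valueSchema: Array[DataType],
      hasKeyOrdering: Boolean,        // new in this commit: the shuffle sorts rows by key
      canUseSqlSerializer2: Boolean,  // stands in for the existing useSqlSerializer2 check
      conf: SparkConf): Serializer = {
    if (canUseSqlSerializer2) {
      // Forward the flag so the deserializing side knows whether it may reuse a
      // single mutable row per stream or must allocate a fresh row for each record.
      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
    } else {
      new SparkSqlSerializer(conf)
    }
  }
}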

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

Lines changed: 44 additions & 21 deletions
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {
 
-  val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,24 +86,44 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging {
 
-  val rowIn = new DataInputStream(new BufferedInputStream(in))
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
 
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }
+
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(), readValueFunc()).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc().asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
  }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc().asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -113,7 +133,8 @@ private[sql] class Serializer2DeserializationStream(
 
 private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -130,7 +151,7 @@ private[sql] class SparkSqlSerializer2Instance(
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }
 
@@ -141,12 +162,16 @@ private[sql] class SparkSqlSerializer2Instance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
 * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -316,12 +341,12 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream): () => Row = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
         var i = 0
-        val mutableRow = new GenericMutableRow(schema.length)
         while (i < schema.length) {
           schema(i) match {
             // When we read values from the underlying stream, we also first read the null byte
@@ -435,8 +460,6 @@ private[sql] object SparkSqlSerializer2 {
         }
 
         mutableRow
-      } else {
-        null
       }
     }
   }
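
The reason for the flag is visible in rowGenerator above: when the ShuffledRDD has a key ordering, the consumer (e.g. a sort) buffers deserialized rows, so handing out one shared SpecificMutableRow would make every buffered row alias the last record read; per the comment in the diff, reuse is safe only when no ordering is requested. The following self-contained sketch is not Spark code (MutableRecord and deserializeAll are hypothetical names); it only illustrates the aliasing effect that the per-record GenericMutableRow allocation avoids.

// Minimal illustration of why row reuse is unsafe once records are buffered,
// e.g. when a ShuffledRDD has a key ordering and the reader sorts its input.
object RowReuseDemo {
  final class MutableRecord(var value: Int)

  def deserializeAll(values: Seq[Int], reuse: Boolean): Seq[MutableRecord] = {
    val shared = new MutableRecord(0) // the single "mutable row" that gets reused
    values.map { v =>
      val rec = if (reuse) shared else new MutableRecord(0) // reuse vs. allocate per record
      rec.value = v                                         // "deserialize" into the record
      rec                                                   // the caller buffers this reference
    }
  }

  def main(args: Array[String]): Unit = {
    // Without reuse, each buffered record keeps its own value...
    assert(deserializeAll(Seq(1, 2, 3), reuse = false).map(_.value) == Seq(1, 2, 3))
    // ...with reuse, every buffered record aliases the last value read.
    assert(deserializeAll(Seq(1, 2, 3), reuse = true).map(_.value) == Seq(3, 3, 3))
  }
}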
