Commit e5011fb

Author: Max Seiden (committed)
[SPARK-5277][SQL] - SparkSqlSerializer does not register user specified KryoRegistrators
There were a few places where new SparkSqlSerializer instances were created with new, empty SparkConfs, so user-specified registrators sometimes did not get initialized. The fix is to pull a conf from the SparkEnv when one is available, and to construct a new conf (that loads defaults) otherwise. The change touches:
1) SparkSqlSerializer's resource pool (this appears to fix the issue noted in the comment)
2) execution.Exchange (for all of the partitioners)
3) execution.Limit (for the HashPartitioner)
A few tests were added to ColumnTypeSuite, ensuring that a custom registrator and serde are initialized and used when in-memory columns are written.
1 parent 5909f09 commit e5011fb
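
In essence, every touched call site now resolves its SparkConf with the same fallback expression. A minimal sketch of that pattern, shown standalone for clarity (taken from the diff below):

import org.apache.spark.{SparkConf, SparkEnv}

// Prefer the conf of the running SparkEnv, which carries user settings such as
// spark.kryo.registrator; fall back to a fresh SparkConf that loads defaults
// only when no SparkEnv is available (e.g. in isolated unit tests).
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())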

File tree: 4 files changed (+68, -11 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 4 additions & 3 deletions
@@ -46,6 +46,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
   override def execute(): RDD[Row] = attachTree(this , "execute") {
+    lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
         // TODO: Eliminate redundant expressions in grouping key and value.
@@ -70,7 +71,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -88,7 +89,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
         val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
 
         shuffled.map(_._1)
 
@@ -107,7 +108,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 2 additions & 5 deletions
@@ -64,12 +64,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
 private[execution] class KryoResourcePool(size: Int)
   extends ResourcePool[SerializerInstance](size) {
 
-  val ser: KryoSerializer = {
+  val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
-    // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
-    // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
-    // related error.
-    new KryoSerializer(sparkConf)
+    new SparkSqlSerializer(sparkConf)
   }
 
   def newInstance(): SerializerInstance = ser.newInstance()

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 2 additions & 1 deletion
@@ -117,7 +117,8 @@ case class Limit(limit: Int, child: SparkPlan)
     }
     val part = new HashPartitioner(1)
     val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
-    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+    shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
     shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

Lines changed: 60 additions & 2 deletions
@@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar
 import java.nio.ByteBuffer
 import java.sql.Timestamp
 
+import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.serializer.KryoRegistrator
 import org.scalatest.FunSuite
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
     checkActualSize(BINARY, binary, 4 + 4)
 
     val generic = Map(1 -> "a")
-    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
+    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
   }
 
   testNativeColumnType[BooleanType.type](
@@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 
+  test("CUSTOM") {
+    val conf = new SparkConf()
+    conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
+    val serializer = new SparkSqlSerializer(conf).newInstance()
+
+    val buffer = ByteBuffer.allocate(512)
+    val obj = CustomClass(Int.MaxValue,Long.MaxValue)
+    val serializedObj = serializer.serialize(obj).array()
+
+    GENERIC.append(serializer.serialize(obj).array(), buffer)
+    buffer.rewind()
+
+    val length = buffer.getInt
+    assert(length === serializedObj.length)
+    assert(13 == length) // id (1) + int (4) + long (8)
+
+    val genericSerializedObj = SparkSqlSerializer.serialize(obj)
+    assert(length != genericSerializedObj.length)
+    assert(length < genericSerializedObj.length)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      val bytes = new Array[Byte](length)
+      buffer.get(bytes, 0, length)
+      serializer.deserialize(ByteBuffer.wrap(bytes))
+    }
+
+    buffer.rewind()
+    buffer.putInt(serializedObj.length).put(serializedObj)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      buffer.rewind()
+      serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+    }
+  }
+
   def testNativeColumnType[T <: NativeType](
       columnType: NativeColumnType[T],
       putter: (ByteBuffer, T#JvmType) => Unit,
@@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 }
+
+private[columnar] final case class CustomClass(a: Int, b: Long)
+
+private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
+  override def write(kryo: Kryo, output: Output, t: CustomClass) {
+    output.writeInt(t.a)
+    output.writeLong(t.b)
+  }
+  override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
+    val a = input.readInt()
+    val b = input.readLong()
+    CustomClass(a,b)
+  }
+}
+
+private[columnar] final class Registrator extends KryoRegistrator {
+  override def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[CustomClass], CustomerSerializer)
+  }
+}
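
As a usage note (a sketch that is not part of this commit; the class and application names below are illustrative assumptions): with the fix in place, a registrator configured on the application's SparkConf is the one SparkSqlSerializer picks up via SparkEnv, so the custom serde also applies to SQL shuffles and in-memory column writes.

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.serializer.KryoRegistrator

// Illustrative user class and registrator; names are placeholders.
case class MyRecord(id: Int, payload: String)

class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[MyRecord])
  }
}

object KryoRegistratorExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("kryo-registrator-example")
      .set("spark.kryo.registrator", classOf[MyRegistrator].getName)
    // After this commit, SparkSqlSerializer reads this conf through SparkEnv,
    // so MyRegistrator is also honored during SQL shuffles and column writes.
    val sc = new SparkContext(conf)
    // ... run Spark SQL queries here ...
    sc.stop()
  }
}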
