
Commit 8a53de1

Max Seiden authored and marmbrus committed
[SPARK-5277][SQL] - SparkSqlSerializer doesn't always register user specified KryoRegistrators
There were a few places where new SparkSqlSerializer instances were created with new, empty SparkConfs, resulting in user-specified registrators sometimes not getting initialized. The fix is to try to pull a conf from the SparkEnv, and to construct a new conf (that loads defaults) if one cannot be found.

The changes touched:
1) SparkSqlSerializer's resource pool (this appears to fix the issue in the comment)
2) execution.Exchange (for all of the partitioners)
3) execution.Limit (for the HashPartitioner)

A few tests were added to ColumnTypeSuite, ensuring that a custom registrator and serde are initialized and used when in-memory columns are written.

Author: Max Seiden <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <[email protected]>

Closes apache#5237 from mhseiden/sql_udt_kryo and squashes the following commits:

3175c2f [Max Seiden] [SPARK-5277][SQL] - address code review comments
e5011fb [Max Seiden] [SPARK-5277][SQL] - SparkSqlSerializer does not register user specified KryoRegistrators
1 parent d5f1b96 commit 8a53de1
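
For readers skimming the diffs below, here is a minimal sketch of the conf-resolution pattern the fix relies on. The resolveConf helper name is illustrative (the patch inlines the expression at each call site): prefer the conf of the running SparkEnv so user-specified settings such as spark.kryo.registrator survive, and only fall back to a SparkConf that loads defaults when no SparkEnv exists.

import org.apache.spark.{SparkConf, SparkEnv}

// Illustrative helper; the patch inlines this expression rather than defining a method.
def resolveConf(): SparkConf =
  Option(SparkEnv.get)           // SparkEnv.get is null when no Spark application is running
    .map(_.conf)                 // reuse the live conf, which carries spark.kryo.registrator
    .getOrElse(new SparkConf())  // otherwise build a conf that loads spark.* system-property defaults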

4 files changed: +68 −12 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 5 additions & 4 deletions
@@ -78,6 +78,8 @@ case class Exchange(
   }
 
   override def execute(): RDD[Row] = attachTree(this , "execute") {
+    lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
         // TODO: Eliminate redundant expressions in grouping key and value.
@@ -109,7 +111,7 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Row, Row](rdd, part)
         }
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -132,8 +134,7 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Null, Null](rdd, part)
         }
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._1)
 
       case SinglePartition =>
@@ -151,7 +152,7 @@ case class Exchange(
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
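
The new SparkConf(false) calls removed above are the heart of the bug: with loadDefaults = false the conf carries no user settings at all, so the registrator key is silently absent. A small illustrative sketch (assuming the registrator was supplied via spark-submit --conf, which surfaces it as a spark.* system property; the class name is hypothetical):

import org.apache.spark.SparkConf

// Assume the user ran: spark-submit --conf spark.kryo.registrator=com.example.MyRegistrator ...
val empty  = new SparkConf(loadDefaults = false)  // what the old code built for shuffle serializers
val loaded = new SparkConf()                      // loadDefaults = true: picks up spark.* system properties

empty.contains("spark.kryo.registrator")   // false -> serializers built from it never run the registrator
loaded.contains("spark.kryo.registrator")  // true  -> user registrator is applied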

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 2 additions & 5 deletions
@@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
 private[execution] class KryoResourcePool(size: Int)
   extends ResourcePool[SerializerInstance](size) {
 
-  val ser: KryoSerializer = {
+  val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
-    // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
-    // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
-    // related error.
-    new KryoSerializer(sparkConf)
+    new SparkSqlSerializer(sparkConf)
   }
 
   def newInstance(): SerializerInstance = ser.newInstance()

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan)
     }
     val part = new HashPartitioner(1)
     val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
-    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
     shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

Lines changed: 60 additions & 2 deletions
@@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar
 import java.nio.ByteBuffer
 import java.sql.Timestamp
 
+import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.serializer.KryoRegistrator
 import org.scalatest.FunSuite
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
     checkActualSize(BINARY, binary, 4 + 4)
 
     val generic = Map(1 -> "a")
-    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
+    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
   }
 
   testNativeColumnType[BooleanType.type](
@@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 
+  test("CUSTOM") {
+    val conf = new SparkConf()
+    conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
+    val serializer = new SparkSqlSerializer(conf).newInstance()
+
+    val buffer = ByteBuffer.allocate(512)
+    val obj = CustomClass(Int.MaxValue,Long.MaxValue)
+    val serializedObj = serializer.serialize(obj).array()
+
+    GENERIC.append(serializer.serialize(obj).array(), buffer)
+    buffer.rewind()
+
+    val length = buffer.getInt
+    assert(length === serializedObj.length)
+    assert(13 == length) // id (1) + int (4) + long (8)
+
+    val genericSerializedObj = SparkSqlSerializer.serialize(obj)
+    assert(length != genericSerializedObj.length)
+    assert(length < genericSerializedObj.length)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      val bytes = new Array[Byte](length)
+      buffer.get(bytes, 0, length)
+      serializer.deserialize(ByteBuffer.wrap(bytes))
+    }
+
+    buffer.rewind()
+    buffer.putInt(serializedObj.length).put(serializedObj)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      buffer.rewind()
+      serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+    }
+  }
+
   def testNativeColumnType[T <: NativeType](
       columnType: NativeColumnType[T],
       putter: (ByteBuffer, T#JvmType) => Unit,
@@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 }
+
+private[columnar] final case class CustomClass(a: Int, b: Long)
+
+private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
+  override def write(kryo: Kryo, output: Output, t: CustomClass) {
+    output.writeInt(t.a)
+    output.writeLong(t.b)
+  }
+  override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
+    val a = input.readInt()
+    val b = input.readLong()
+    CustomClass(a,b)
+  }
+}
+
+private[columnar] final class Registrator extends KryoRegistrator {
+  override def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[CustomClass], CustomerSerializer)
+  }
+}
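
The test above wires the registrator the same way an application would. For context, a minimal user-facing sketch (the registrator and the registered type are hypothetical, not part of this patch); with the fix, the SparkSqlSerializer instances created for Exchange and Limit shuffles and in the Kryo resource pool pick this setting up from the SparkEnv conf:

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical application type that ends up in shuffled rows or in-memory columns.
case class MyUdt(id: Int, payload: Long)

// Hypothetical application-defined registrator.
class MyKryoRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[MyUdt])  // register application types that need custom Kryo handling
  }
}

// Application-side configuration; the fixed code paths read this key from the running conf.
val conf = new SparkConf()
  .setAppName("kryo-registrator-example")
  .set("spark.kryo.registrator", classOf[MyKryoRegistrator].getName)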
