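Summary of the change set below: the SQL exchange operators previously built their shuffle-side SparkSqlSerializer from `new SparkConf(false)`, an empty configuration, so user settings such as `spark.kryo.registrator` never reached the shuffle path. This diff threads the real SparkConf from the SparkContext into those serializers, switches the pooled serializer in KryoResourcePool from plain KryoSerializer to SparkSqlSerializer, and adds a ColumnTypeSuite test showing that a custom Kryo serializer registered through the configuration is actually honored (which also shrinks the expected serialized size of `Map(1 -> "a")` from 11 to 8 bytes).
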
@@ -46,6 +46,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
   override def execute(): RDD[Row] = attachTree(this , "execute") {
+    lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
         // TODO: Eliminate redundant expressions in grouping key and value.
@@ -70,7 +72,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
         }
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -88,7 +90,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
         val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
 
         shuffled.map(_._1)
 
@@ -107,7 +109,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
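
Note on the fix above: `new SparkConf(false)` constructs a configuration with `loadDefaults = false`, so it carries neither JVM system properties nor anything the user set on the application's conf. A minimal sketch of the difference, where the registrator class name is hypothetical:

    import org.apache.spark.SparkConf

    // Fresh conf with loadDefaults = false: ignores system properties and knows
    // nothing about the application's settings.
    val empty = new SparkConf(false)

    // The application's conf, as returned by sparkContext.getConf in the patch.
    val appConf = new SparkConf()
      .set("spark.kryo.registrator", "com.example.MyRegistrator") // hypothetical name

    empty.contains("spark.kryo.registrator")   // false: the registrator is silently dropped
    appConf.contains("spark.kryo.registrator") // true: it reaches SparkSqlSerializer

Caching the conf in a `lazy val` also means it is looked up once per `execute()` call rather than once per serializer construction.
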
@@ -64,12 +64,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf)
 private[execution] class KryoResourcePool(size: Int)
   extends ResourcePool[SerializerInstance](size) {
 
-  val ser: KryoSerializer = {
+  val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
-    // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
-    //   Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
-    //   related error.
-    new KryoSerializer(sparkConf)
+    new SparkSqlSerializer(sparkConf)
   }
 
   def newInstance(): SerializerInstance = ser.newInstance()
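
Background on the pool being modified: Kryo instances are expensive to construct and serializer instances are not thread-safe, which is why KryoResourcePool hands out pooled `SerializerInstance`s instead of building one per call. The functional change here is what gets pooled: SparkSqlSerializer adds the SQL-specific Kryo registrations that the previously pooled plain KryoSerializer lacked (the removed TODO records that an earlier attempt at this hit Kryo errors in BasicQuerySuite). A rough sketch of the pooling idea with hypothetical names, since Spark's ResourcePool API may differ:

    import java.util.concurrent.ArrayBlockingQueue

    // Hypothetical simplification of a fixed-size resource pool.
    class SimplePool[T](size: Int, create: () => T) {
      private val pool = new ArrayBlockingQueue[T](size)

      // Reuse an idle instance if one is available, otherwise build a new one.
      def borrow(): T = Option(pool.poll()).getOrElse(create())

      // Return the instance; if the pool is already full it is simply discarded.
      def release(obj: T): Unit = { pool.offer(obj) }
    }
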
@@ -117,7 +117,7 @@ case class Limit(limit: Int, child: SparkPlan)
     }
     val part = new HashPartitioner(1)
     val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
-    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
     shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }
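
The fix here is the same conf substitution as in Exchange, inlined rather than cached in a `lazy val`. For context, the surrounding `Limit.execute` implements a two-phase take: each partition keeps at most `limit` rows, everything is shuffled into a single partition, and the first `limit` rows of the merged stream are kept. A simplified, self-contained sketch of the same shape, using `Int` in place of `Row`:

    import org.apache.spark.{HashPartitioner, SparkContext}
    import org.apache.spark.rdd.{RDD, ShuffledRDD}

    def naiveLimit(sc: SparkContext, data: Seq[Int], limit: Int): RDD[Int] = {
      // Phase 1: cap every partition at `limit` rows, keyed so they can be shuffled.
      val capped = sc.parallelize(data).mapPartitions(_.take(limit).map(x => (false, x)))
      // Phase 2: shuffle the survivors into one partition and take `limit` globally.
      val shuffled = new ShuffledRDD[Boolean, Int, Int](capped, new HashPartitioner(1))
      shuffled.mapPartitions(_.take(limit).map(_._2))
    }
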
@@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar
 import java.nio.ByteBuffer
 import java.sql.Timestamp
 
+import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.serializer.KryoRegistrator
 import org.scalatest.FunSuite
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
     checkActualSize(BINARY, binary, 4 + 4)
 
     val generic = Map(1 -> "a")
-    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
+    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
   }
 
   testNativeColumnType[BooleanType.type](
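
About the `4 + 11` to `4 + 8` change: a GENERIC cell is laid out as a 4-byte length prefix followed by the serialized bytes (the CUSTOM test below manipulates exactly this layout), so the expectation is `4 + serializedBytes.length`. With the resource pool now backed by SparkSqlSerializer, `Map(1 -> "a")` serializes more compactly through the SQL-specific Kryo registrations, 8 bytes instead of 11. A sketch of the layout, mirroring the `putInt`/`getInt` calls the CUSTOM test makes against the buffer:

    import java.nio.ByteBuffer

    // [4-byte length][payload]: the layout GENERIC.append / GENERIC.extract traverse.
    def appendGeneric(payload: Array[Byte], buffer: ByteBuffer): Unit = {
      buffer.putInt(payload.length).put(payload)
    }

    def extractGeneric(buffer: ByteBuffer): Array[Byte] = {
      val payload = new Array[Byte](buffer.getInt())
      buffer.get(payload)
      payload
    }
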
@@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 
+  test("CUSTOM") {
+    val conf = new SparkConf()
+    conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
+    val serializer = new SparkSqlSerializer(conf).newInstance()
+
+    val buffer = ByteBuffer.allocate(512)
+    val obj = CustomClass(Int.MaxValue, Long.MaxValue)
+    val serializedObj = serializer.serialize(obj).array()
+
+    GENERIC.append(serializer.serialize(obj).array(), buffer)
+    buffer.rewind()
+
+    val length = buffer.getInt
+    assert(length === serializedObj.length)
+    assert(13 == length) // id (1) + int (4) + long (8)
+
+    val genericSerializedObj = SparkSqlSerializer.serialize(obj)
+    assert(length != genericSerializedObj.length)
+    assert(length < genericSerializedObj.length)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      val bytes = new Array[Byte](length)
+      buffer.get(bytes, 0, length)
+      serializer.deserialize(ByteBuffer.wrap(bytes))
+    }
+
+    buffer.rewind()
+    buffer.putInt(serializedObj.length).put(serializedObj)
+
+    assertResult(obj, "Custom deserialized object didn't equal the original object") {
+      buffer.rewind()
+      serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+    }
+  }
+
   def testNativeColumnType[T <: NativeType](
     columnType: NativeColumnType[T],
     putter: (ByteBuffer, T#JvmType) => Unit,
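
What the assertions above pin down: with `Registrator` installed via `spark.kryo.registrator`, `CustomClass` is written by `CustomerSerializer` as a 1-byte Kryo registration id plus a fixed 4-byte int and 8-byte long, 13 bytes total. The default path for an unregistered class falls back to Kryo's reflective field serialization and also writes the full class name, hence the strictly larger `genericSerializedObj`. A standalone sketch of the same size comparison with raw Kryo (hypothetical usage, not from the suite):

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.Output

    def sizeOf(kryo: Kryo, obj: AnyRef): Int = {
      val out = new Output(256)
      kryo.writeClassAndObject(out, obj) // writes class info followed by the payload
      out.flush()
      out.total().toInt
    }

    val plain = new Kryo() // unregistered: reflective fields plus full class name
    val tuned = new Kryo()
    tuned.register(classOf[CustomClass], CustomerSerializer)

    val obj = CustomClass(Int.MaxValue, Long.MaxValue)
    assert(sizeOf(tuned, obj) < sizeOf(plain, obj)) // mirrors length < genericSerializedObj.length
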
@@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging {
     }
   }
 }
+
+private[columnar] final case class CustomClass(a: Int, b: Long)
+
+private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
+  override def write(kryo: Kryo, output: Output, t: CustomClass) {
+    output.writeInt(t.a)
+    output.writeLong(t.b)
+  }
+  override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
+    val a = input.readInt()
+    val b = input.readLong()
+    CustomClass(a, b)
+  }
+}
+
+private[columnar] final class Registrator extends KryoRegistrator {
+  override def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[CustomClass], CustomerSerializer)
+  }
+}
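
Tying the test fixture back to the main change: registering `Registrator` only helps if the serializers on the shuffle path are built from a conf that actually contains the setting, which is what the Exchange and Limit edits guarantee. A sketch of how an application would switch it on (master and app name are placeholders):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("custom-kryo-demo")
      .set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")

    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    // Exchange and Limit now construct SparkSqlSerializer from sc.getConf, so shuffled
    // rows are serialized with the custom registrations instead of an empty SparkConf.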