@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
+import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -77,9 +79,48 @@ case class Exchange(
     }
   }
 
-  override def execute(): RDD[Row] = attachTree(this, "execute") {
-    lazy val sparkConf = child.sqlContext.sparkContext.getConf
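+  // @transient: SparkConf is driver-side state and should not be serialized with this operator.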
+  @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
+  def serializer(
+      keySchema: Array[DataType],
+      valueSchema: Array[DataType],
+      numPartitions: Int): Serializer = {
+    // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
+    // through write(key) and then write(value) instead of write((key, value)). Because
+    // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
+    // it when spillToMergeableFile in ExternalSorter will be used.
+    // So, we will not use SparkSqlSerializer2 when
+    //  - sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
+    //    than the bypassMergeThreshold; or
+    //  - newOrdering is defined.
+    val cannotUseSqlSerializer2 =
+      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
+
+    // noField is true when there is no field that needs to be written out.
+    // For now, we will not use SparkSqlSerializer2 when noField is true.
+    val noField =
+      (keySchema == null || keySchema.length == 0) &&
+      (valueSchema == null || valueSchema.length == 0)
+
+    val useSqlSerializer2 =
+      child.sqlContext.conf.useSqlSerializer2 &&   // SparkSqlSerializer2 is enabled.
+      !cannotUseSqlSerializer2 &&                  // Safe to use Serializer2.
+      SparkSqlSerializer2.support(keySchema) &&    // The schema of the key is supported.
+      SparkSqlSerializer2.support(valueSchema) &&  // The schema of the value is supported.
+      !noField
+
+    val serializer = if (useSqlSerializer2) {
+      logInfo("Using SparkSqlSerializer2.")
+      new SparkSqlSerializer2(keySchema, valueSchema)
+    } else {
+      logInfo("Using SparkSqlSerializer.")
+      new SparkSqlSerializer(sparkConf)
+    }
+
+    serializer
+  }
 
+  override def execute(): RDD[Row] = attachTree(this, "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
         // TODO: Eliminate redundant expressions in grouping key and value.
@@ -111,7 +152,10 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Row, Row](rdd, part)
         }
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
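+        // The shuffle key is the row built from the grouping expressions and the value is the
+        // child's full output row, so both schemas are passed to the serializer.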
+        val keySchema = expressions.map(_.dataType).toArray
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -134,7 +178,9 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Null, Null](rdd, part)
         }
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
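+        // Range partitioning shuffles (row, null) pairs, so only a key schema is needed here.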
+        val keySchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+
         shuffled.map(_._1)
 
       case SinglePartition =>
@@ -152,7 +198,8 @@ case class Exchange(
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
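+        // A single-partition shuffle writes (null, row) pairs, so only a value schema is needed.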
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(null, valueSchema, 1))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
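
A minimal standalone sketch (hypothetical names, not part of this patch) of the selection rule the new serializer method encodes: SparkSqlSerializer2 is chosen only when it is enabled, the shuffle path will not ask the serializer to write keys and values separately, both schemas are supported, and there is at least one field to write.

object SerializerChoiceSketch {
  // Mirrors the predicate above with plain parameters instead of plan state.
  def useSerializer2(
      serializer2Enabled: Boolean,
      sortBasedShuffleOn: Boolean,
      bypassMergeThreshold: Int,
      numPartitions: Int,
      hasNewOrdering: Boolean,
      keyFieldCount: Int,
      valueFieldCount: Int,
      keySchemaSupported: Boolean,
      valueSchemaSupported: Boolean): Boolean = {
    // ExternalSorter may spill keys and values separately, which Serializer2 cannot handle.
    val cannotUseSerializer2 =
      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || hasNewOrdering
    // Nothing to write at all: fall back to the generic serializer.
    val noField = keyFieldCount == 0 && valueFieldCount == 0
    serializer2Enabled && !cannotUseSerializer2 &&
      keySchemaSupported && valueSchemaSupported && !noField
  }

  def main(args: Array[String]): Unit = {
    // Few reducers under sort-based shuffle: Serializer2 can be used.
    println(useSerializer2(serializer2Enabled = true, sortBasedShuffleOn = true,
      bypassMergeThreshold = 200, numPartitions = 50, hasNewOrdering = false,
      keyFieldCount = 1, valueFieldCount = 3,
      keySchemaSupported = true, valueSchemaSupported = true))  // true
    // Many reducers under sort-based shuffle: fall back to SparkSqlSerializer.
    println(useSerializer2(serializer2Enabled = true, sortBasedShuffleOn = true,
      bypassMergeThreshold = 200, numPartitions = 500, hasNewOrdering = false,
      keyFieldCount = 1, valueFieldCount = 3,
      keySchemaSupported = true, valueSchemaSupported = true))  // false
  }
}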