@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
148148 table(" shuffle" ).collect())
149149 }
150150
151+ test(" key schema is null" ) {
152+ val aggregations = allColumns.split(" ," ).map(c => s " COUNT( $c) " ).mkString(" ," )
153+ val df = sql(s " SELECT $aggregations FROM shuffle " )
154+ checkSerializer(df.queryExecution.executedPlan, serializerClass)
155+ checkAnswer(
156+ df,
157+ Row (1000 , 1000 , 0 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 ))
158+ }
159+
151160 test(" value schema is null" ) {
152161 val df = sql(s " SELECT col0 FROM shuffle ORDER BY col0 " )
153162 checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
167176 override def beforeAll (): Unit = {
168177 super .beforeAll()
169178 // Sort merge will not be triggered.
170- sql(" set spark.sql.shuffle.partitions = 200" )
171- }
172-
173- test(" key schema is null" ) {
174- val aggregations = allColumns.split(" ," ).map(c => s " COUNT( $c) " ).mkString(" ," )
175- val df = sql(s " SELECT $aggregations FROM shuffle " )
176- checkSerializer(df.queryExecution.executedPlan, serializerClass)
177- checkAnswer(
178- df,
179- Row (1000 , 1000 , 0 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 , 1000 ))
179+ val bypassMergeThreshold =
180+ sparkContext.conf.getInt(" spark.shuffle.sort.bypassMergeThreshold" , 200 )
181+ sql(s " set spark.sql.shuffle.partitions= ${bypassMergeThreshold- 1 }" )
180182 }
181183}
182184
183185/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
184186class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
185187
186- // We are expecting SparkSqlSerializer.
187- override val serializerClass : Class [Serializer ] =
188- classOf [SparkSqlSerializer ].asInstanceOf [Class [Serializer ]]
189-
190188 override def beforeAll (): Unit = {
191189 super .beforeAll()
192190 // To trigger the sort merge.
193- sql(" set spark.sql.shuffle.partitions = 201" )
191+ val bypassMergeThreshold =
192+ sparkContext.conf.getInt(" spark.shuffle.sort.bypassMergeThreshold" , 200 )
193+ sql(s " set spark.sql.shuffle.partitions= ${bypassMergeThreshold + 1 }" )
194194 }
195195}
0 commit comments