Commit c7e2129: "Update tests."
1 parent: 4513d13

1 file changed: +15 additions, -15 deletions

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala

Lines changed: 15 additions & 15 deletions
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       table("shuffle").collect())
   }
 
+  test("key schema is null") {
+    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+  }
+
   test("value schema is null") {
     val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
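
The "key schema is null" test moves from the sort-shuffle suite up into
the abstract SparkSqlSerializer2Suite, so both concrete suites now run
it. It covers the case where the exchange carries no grouping keys: a
global aggregate (no GROUP BY) shuffles rows whose key schema is null.
A minimal sketch of that shape, assuming the suite's sql(...) helper and
its "shuffle" fixture table:

    // No GROUP BY, so the shuffle feeding the final aggregation has a
    // null/empty key schema; only the value schema carries data.
    val df = sql("SELECT COUNT(col0), COUNT(col1) FROM shuffle")

The lone 0 among the expected counts suggests the fixture's third column
is entirely NULL, since COUNT skips NULL values.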
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Sort merge will not be triggered.
-    sql("set spark.sql.shuffle.partitions = 200")
-  }
-
-  test("key schema is null") {
-    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sql(s"SELECT $aggregations FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
   }
 }
 
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
-  // We are expecting SparkSqlSerializer.
-  override val serializerClass: Class[Serializer] =
-    classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     // To trigger the sort merge.
-    sql("set spark.sql.shuffle.partitions = 201")
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
   }
 }
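
Deriving the partition counts from spark.shuffle.sort.bypassMergeThreshold,
rather than hardcoding 200 and 201, keeps the two suites on opposite sides
of the bypass decision even if the default threshold changes. A simplified
sketch of that decision, loosely modeled on Spark's sort-based shuffle (a
hypothetical helper, not the actual internal code):

    import org.apache.spark.SparkConf

    // The merge-sort path is bypassed only when the map side needs no
    // aggregation or key ordering AND the reduce-partition count is at
    // or below the threshold. The suites pick threshold - 1 and
    // threshold + 1 partitions to land on either side of this check.
    def bypassesMergeSort(
        conf: SparkConf,
        numPartitions: Int,
        needsAggregationOrOrdering: Boolean): Boolean = {
      val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      !needsAggregationOrOrdering && numPartitions <= threshold
    }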
