@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.adaptive
1919
2020import org .apache .spark .sql .QueryTest
2121import org .apache .spark .sql .execution .{ReusedSubqueryExec , SparkPlan }
22- import org .apache .spark .sql .execution .adaptive .rule .CoalescedShuffleReaderExec
2322import org .apache .spark .sql .execution .exchange .Exchange
2423import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , BuildRight , SortMergeJoinExec }
2524import org .apache .spark .sql .internal .SQLConf
@@ -78,7 +77,7 @@ class AdaptiveQueryExecSuite
7877 }
7978
8079 private def checkNumLocalShuffleReaders (plan : SparkPlan , expected : Int ): Unit = {
81- val localReaders = plan. collect {
80+ val localReaders = collect(plan) {
8281 case reader : LocalShuffleReaderExec => reader
8382 }
8483 assert(localReaders.length === expected)
@@ -164,7 +163,7 @@ class AdaptiveQueryExecSuite
164163 assert(smj.size == 3 )
165164 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
166165 assert(bhj.size == 3 )
167-
166+ // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader.
168167 checkNumLocalShuffleReaders(adaptivePlan, 1 )
169168 }
170169 }
@@ -189,8 +188,8 @@ class AdaptiveQueryExecSuite
189188 assert(smj.size == 3 )
190189 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
191190 assert(bhj.size == 3 )
192-
193- checkNumLocalShuffleReaders(adaptivePlan, 0 )
191+ // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader.
192+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
194193 }
195194 }
196195
@@ -214,7 +213,8 @@ class AdaptiveQueryExecSuite
214213 assert(smj.size == 3 )
215214 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
216215 assert(bhj.size == 3 )
217- checkNumLocalShuffleReaders(adaptivePlan, 0 )
216+ // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader.
217+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
218218 }
219219 }
220220
@@ -229,6 +229,8 @@ class AdaptiveQueryExecSuite
229229 assert(smj.size == 3 )
230230 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
231231 assert(bhj.size == 2 )
232+ checkNumLocalShuffleReaders(adaptivePlan, 2 )
233+ // Even with local shuffle reader, the query statge reuse can also work.
232234 val ex = findReusedExchange(adaptivePlan)
233235 assert(ex.size == 1 )
234236 }
@@ -245,6 +247,8 @@ class AdaptiveQueryExecSuite
245247 assert(smj.size == 1 )
246248 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
247249 assert(bhj.size == 1 )
250+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
251+ // Even with local shuffle reader, the query statge reuse can also work.
248252 val ex = findReusedExchange(adaptivePlan)
249253 assert(ex.size == 1 )
250254 }
@@ -263,6 +267,8 @@ class AdaptiveQueryExecSuite
263267 assert(smj.size == 1 )
264268 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
265269 assert(bhj.size == 1 )
270+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
271+ // Even with local shuffle reader, the query statge reuse can also work.
266272 val ex = findReusedExchange(adaptivePlan)
267273 assert(ex.nonEmpty)
268274 val sub = findReusedSubquery(adaptivePlan)
@@ -282,6 +288,8 @@ class AdaptiveQueryExecSuite
282288 assert(smj.size == 1 )
283289 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
284290 assert(bhj.size == 1 )
291+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
292+ // Even with local shuffle reader, the query statge reuse can also work.
285293 val ex = findReusedExchange(adaptivePlan)
286294 assert(ex.isEmpty)
287295 val sub = findReusedSubquery(adaptivePlan)
@@ -304,6 +312,8 @@ class AdaptiveQueryExecSuite
304312 assert(smj.size == 1 )
305313 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
306314 assert(bhj.size == 1 )
315+ checkNumLocalShuffleReaders(adaptivePlan, 1 )
316+ // Even with local shuffle reader, the query statge reuse can also work.
307317 val ex = findReusedExchange(adaptivePlan)
308318 assert(ex.nonEmpty)
309319 assert(ex.head.plan.isInstanceOf [BroadcastQueryStageExec ])
0 commit comments