@@ -30,7 +30,9 @@ import org.apache.spark.sql.catalyst.InternalRow
3030import org .apache .spark .sql .catalyst .plans .logical .Aggregate
3131import org .apache .spark .sql .catalyst .plans .physical .SinglePartition
3232import org .apache .spark .sql .catalyst .util .DateTimeUtils
33- import org .apache .spark .sql .execution .SparkPlan
33+ import org .apache .spark .sql .execution .{SparkPlan , WholeStageCodegenExec }
34+ import org .apache .spark .sql .execution .aggregate .HashAggregateExec
35+ import org .apache .spark .sql .execution .exchange .ShuffleExchange
3436import org .apache .spark .sql .execution .streaming ._
3537import org .apache .spark .sql .execution .streaming .state .StateStore
3638import org .apache .spark .sql .expressions .scalalang .typed
@@ -417,10 +419,30 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
417419 spark.table(" agg_test" ).as[Long ],
418420 1L )
419421
422+ val restore1 = sq.asInstanceOf [StreamingQueryWrapper ].streamingQuery
423+ .lastExecution.executedPlan
424+ .collect { case ss : StateStoreRestoreExec => ss }
425+ .head
426+ restore1.child match {
427+ case wscg : WholeStageCodegenExec =>
428+ assert(wscg.outputPartitioning.numPartitions === 1 )
429+ assert(wscg.child.isInstanceOf [HashAggregateExec ], " Shouldn't require shuffling" )
430+ case _ => fail(" Expected no shuffling" )
431+ }
432+
420433 inputSource.addData()
421434 inputSource.releaseLock()
422435 sq.processAllAvailable()
423436
437+ val restore2 = sq.asInstanceOf [StreamingQueryWrapper ].streamingQuery
438+ .lastExecution.executedPlan
439+ .collect { case ss : StateStoreRestoreExec => ss }
440+ .head
441+ restore2.child match {
442+ case shuffle : ShuffleExchange => assert(shuffle.newPartitioning.numPartitions === 1 )
443+ case _ => fail(" Expected shuffling when there was no data" )
444+ }
445+
424446 checkDataset(
425447 spark.table(" agg_test" ).as[Long ],
426448 1L )
@@ -472,6 +494,17 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
472494 inputSource.releaseLock()
473495 sq.processAllAvailable()
474496
497+ val restore1 = sq.asInstanceOf [StreamingQueryWrapper ].streamingQuery
498+ .lastExecution.executedPlan
499+ .collect { case ss : StateStoreRestoreExec => ss }
500+ .head
501+ restore1.child match {
502+ case shuffle : ShuffleExchange =>
503+ assert(shuffle.newPartitioning.numPartitions ===
504+ spark.sessionState.conf.numShufflePartitions)
505+ case _ => fail(s " Expected shuffling but got: ${restore1.child}" )
506+ }
507+
475508 checkDataset(
476509 spark.table(" agg_test" ).as[(Long , Long )],
477510 (0L , 1L ))
@@ -502,6 +535,19 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
502535 inputSource.releaseLock()
503536 sq2.processAllAvailable()
504537
538+ val restore2 = sq2.asInstanceOf [StreamingQueryWrapper ].streamingQuery
539+ .lastExecution.executedPlan
540+ .collect { case ss : StateStoreRestoreExec => ss }
541+ .head
542+ restore2.child match {
543+ case wscg : WholeStageCodegenExec =>
544+ assert(wscg.outputPartitioning.numPartitions ===
545+ spark.sessionState.conf.numShufflePartitions)
546+ case _ =>
547+ fail(" Shouldn't require shuffling as HashAggregateExec should have asked for a " +
548+ s " shuffle. But got: ${restore2.child}" )
549+ }
550+
505551 checkDataset(
506552 spark.table(" agg_test" ).as[(Long , Long )],
507553 (0L , 4L ))
0 commit comments