
Commit 2f94951

Added more checks
1 parent: 090044c · commit: 2f94951

File tree: 2 files changed (+64 lines, −3 lines)

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 17 additions & 2 deletions
@@ -30,8 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProj
 import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
 import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.RDDScanExec
-import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.{RDDScanExec, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream, StreamingQueryWrapper}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock}
@@ -917,6 +918,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     inputSource.releaseLock()
     sq.processAllAvailable()

+    val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: FlatMapGroupsWithStateExec => ss }
+      .head
+    assert(restore1.child.outputPartitioning.numPartitions ===
+      spark.sessionState.conf.numShufflePartitions)
+
     checkDataset(
       spark.read.parquet(data).as[Int],
       1)
@@ -939,6 +947,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     inputSource.releaseLock()
     sq2.processAllAvailable()

+    val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: FlatMapGroupsWithStateExec => ss }
+      .head
+    assert(restore2.child.outputPartitioning.numPartitions ===
+      spark.sessionState.conf.numShufflePartitions)
+
     checkDataset(
       spark.read.parquet(data).as[Int],
       4, 3, 1)
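
Both assertions added to this suite follow the same plan-inspection pattern: unwrap the user-facing query handle (a StreamingQueryWrapper) to reach the internal streaming query's most recent executed physical plan, collect the FlatMapGroupsWithStateExec node, and check its child's output partitioning. A minimal sketch of that pattern, assuming a started streaming query `sq` as in the suite (the helper name is ours, not from the commit):

import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, StreamingQueryWrapper}
import org.apache.spark.sql.streaming.StreamingQuery

// Hypothetical helper mirroring the lookup in the diff above.
def flatMapGroupsWithStateNode(sq: StreamingQuery): FlatMapGroupsWithStateExec = {
  // The wrapper exposes the internal StreamExecution; lastExecution holds
  // the physical plan of the most recently completed micro-batch.
  sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
    .lastExecution.executedPlan
    .collect { case f: FlatMapGroupsWithStateExec => f }
    .head
}

// Used as in the assertions above:
// assert(flatMapGroupsWithStateNode(sq).child.outputPartitioning.numPartitions ===
//   spark.sessionState.conf.numShufflePartitions)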

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 47 additions & 1 deletion
@@ -30,7 +30,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -417,10 +419,30 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       spark.table("agg_test").as[Long],
       1L)

+    val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore1.child match {
+      case wscg: WholeStageCodegenExec =>
+        assert(wscg.outputPartitioning.numPartitions === 1)
+        assert(wscg.child.isInstanceOf[HashAggregateExec], "Shouldn't require shuffling")
+      case _ => fail("Expected no shuffling")
+    }
+
     inputSource.addData()
     inputSource.releaseLock()
     sq.processAllAvailable()

+    val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore2.child match {
+      case shuffle: ShuffleExchange => assert(shuffle.newPartitioning.numPartitions === 1)
+      case _ => fail("Expected shuffling when there was no data")
+    }
+
     checkDataset(
       spark.table("agg_test").as[Long],
       1L)
@@ -472,6 +494,17 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     inputSource.releaseLock()
     sq.processAllAvailable()

+    val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore1.child match {
+      case shuffle: ShuffleExchange =>
+        assert(shuffle.newPartitioning.numPartitions ===
+          spark.sessionState.conf.numShufflePartitions)
+      case _ => fail(s"Expected shuffling but got: ${restore1.child}")
+    }
+
     checkDataset(
       spark.table("agg_test").as[(Long, Long)],
       (0L, 1L))
@@ -502,6 +535,19 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     inputSource.releaseLock()
     sq2.processAllAvailable()

+    val restore2 = sq2.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore2.child match {
+      case wscg: WholeStageCodegenExec =>
+        assert(wscg.outputPartitioning.numPartitions ===
+          spark.sessionState.conf.numShufflePartitions)
+      case _ =>
+        fail("Shouldn't require shuffling as HashAggregateExec should have asked for a " +
+          s"shuffle. But got: ${restore2.child}")
+    }
+
     checkDataset(
       spark.table("agg_test").as[(Long, Long)],
       (0L, 4L))
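
The four checks added to this suite exercise both planner outcomes at the StateStoreRestoreExec boundary: when the upstream HashAggregateExec already produces the required distribution, the restore node's child is the WholeStageCodegenExec stage itself and no exchange is inserted; otherwise the planner adds a ShuffleExchange whose newPartitioning carries the required partition count. A hedged sketch of that case analysis (the helper is ours, not part of the commit; ShuffleExchange is the Spark 2.2-era class name used in this diff):

import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchange

// Hypothetical helper summarizing what each branch of the matches above asserts.
def describeRestoreChild(child: SparkPlan): String = child match {
  // No exchange: the child stage already satisfied the required partitioning.
  case wscg: WholeStageCodegenExec =>
    s"no shuffle; ${wscg.outputPartitioning.numPartitions} partition(s)"
  // The planner inserted an exchange to meet the required distribution.
  case shuffle: ShuffleExchange =>
    s"shuffle to ${shuffle.newPartitioning.numPartitions} partition(s)"
  case other =>
    s"unexpected child: ${other.nodeName}"
}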
