@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
2929import org .apache .spark .sql .execution .streaming .{IncrementalExecution , OffsetSeqMetadata , StatefulOperator , StatefulOperatorStateInfo }
3030import org .apache .spark .sql .test .SharedSQLContext
3131
32- class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext {
32+ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
3333
3434 import testImplicits ._
3535 super .beforeAll()
@@ -39,39 +39,46 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
3939 testEnsureStatefulOpPartitioning(
4040 " ClusteredDistribution generates Exchange with HashPartitioning" ,
4141 baseDf.queryExecution.sparkPlan,
42- keys => ClusteredDistribution (keys),
43- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
42+ requiredDistribution = keys => ClusteredDistribution (keys),
43+ expectedPartitioning =
44+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
4445 expectShuffle = true )
4546
4647 testEnsureStatefulOpPartitioning(
4748 " ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning" ,
4849 baseDf.coalesce(1 ).queryExecution.sparkPlan,
49- keys => ClusteredDistribution (keys),
50- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
50+ requiredDistribution = keys => ClusteredDistribution (keys),
51+ expectedPartitioning =
52+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
5153 expectShuffle = true )
5254
5355 testEnsureStatefulOpPartitioning(
5456 " AllTuples generates Exchange with SinglePartition" ,
5557 baseDf.queryExecution.sparkPlan,
56- keys => AllTuples ,
57- keys => SinglePartition ,
58+ requiredDistribution = _ => AllTuples ,
59+ expectedPartitioning = _ => SinglePartition ,
5860 expectShuffle = true )
5961
6062 testEnsureStatefulOpPartitioning(
6163 " AllTuples with coalesce(1) doesn't need Exchange" ,
6264 baseDf.coalesce(1 ).queryExecution.sparkPlan,
63- keys => AllTuples ,
64- keys => SinglePartition ,
65+ requiredDistribution = _ => AllTuples ,
66+ expectedPartitioning = _ => SinglePartition ,
6567 expectShuffle = false )
6668
69+ /**
70+ * For a `StatefulOperator` with the given `requiredChildDistribution` and child SparkPlan
71+ * `inputPlan`, verifies that the incremental planner adds exchanges, if required, to
72+ * produce the expected partitioning.
73+ */
6774 private def testEnsureStatefulOpPartitioning (
6875 testName : String ,
6976 inputPlan : SparkPlan ,
7077 requiredDistribution : Seq [Attribute ] => Distribution ,
7178 expectedPartitioning : Seq [Attribute ] => Partitioning ,
7279 expectShuffle : Boolean ): Unit = {
73- test(" EnsureStatefulOpPartitioning - " + testName) {
74- val operator = TestOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
80+ test(testName) {
81+ val operator = TestStatefulOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
7582 val executed = executePlan(operator, OutputMode .Complete ())
7683 if (expectShuffle) {
7784 val exchange = executed.children.find(_.isInstanceOf [Exchange ])
@@ -88,6 +95,7 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
8895 }
8996 }
9097
98+ /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
9199 private def executePlan (
92100 p : SparkPlan ,
93101 outputMode : OutputMode = OutputMode .Append ()): SparkPlan = {
@@ -111,13 +119,14 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
111119 }
112120 execution.executedPlan
113121 }
114- }
115122
116- case class TestOperator (
117- child : SparkPlan ,
118- requiredDist : Distribution ) extends UnaryExecNode with StatefulOperator {
119- override def output : Seq [Attribute ] = child.output
120- override def doExecute (): RDD [InternalRow ] = child.execute()
121- override def requiredChildDistribution : Seq [Distribution ] = requiredDist :: Nil
122- override def stateInfo : Option [StatefulOperatorStateInfo ] = None
123+ /** Test-only [[StatefulOperator]] that forwards its child's output and execution unchanged, requires `requiredDist` of that child, and reports no state info. */
124+ case class TestStatefulOperator (
125+ child : SparkPlan ,
126+ requiredDist : Distribution ) extends UnaryExecNode with StatefulOperator {
127+ override def output : Seq [Attribute ] = child.output
128+ override def doExecute (): RDD [InternalRow ] = child.execute()
129+ override def requiredChildDistribution : Seq [Distribution ] = requiredDist :: Nil
130+ override def stateInfo : Option [StatefulOperatorStateInfo ] = None
131+ }
123132}
0 commit comments