Skip to content

Commit f34fc8a

Browse files
committed
address
1 parent be39125 commit f34fc8a

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed
Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
2929
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
3030
import org.apache.spark.sql.test.SharedSQLContext
3131

32-
class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext {
32+
class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
3333

3434
import testImplicits._
3535
super.beforeAll()
@@ -39,39 +39,46 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
3939
testEnsureStatefulOpPartitioning(
4040
"ClusteredDistribution generates Exchange with HashPartitioning",
4141
baseDf.queryExecution.sparkPlan,
42-
keys => ClusteredDistribution(keys),
43-
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
42+
requiredDistribution = keys => ClusteredDistribution(keys),
43+
expectedPartitioning =
44+
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
4445
expectShuffle = true)
4546

4647
testEnsureStatefulOpPartitioning(
4748
"ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning",
4849
baseDf.coalesce(1).queryExecution.sparkPlan,
49-
keys => ClusteredDistribution(keys),
50-
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
50+
requiredDistribution = keys => ClusteredDistribution(keys),
51+
expectedPartitioning =
52+
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
5153
expectShuffle = true)
5254

5355
testEnsureStatefulOpPartitioning(
5456
"AllTuples generates Exchange with SinglePartition",
5557
baseDf.queryExecution.sparkPlan,
56-
keys => AllTuples,
57-
keys => SinglePartition,
58+
requiredDistribution = _ => AllTuples,
59+
expectedPartitioning = _ => SinglePartition,
5860
expectShuffle = true)
5961

6062
testEnsureStatefulOpPartitioning(
6163
"AllTuples with coalesce(1) doesn't need Exchange",
6264
baseDf.coalesce(1).queryExecution.sparkPlan,
63-
keys => AllTuples,
64-
keys => SinglePartition,
65+
requiredDistribution = _ => AllTuples,
66+
expectedPartitioning = _ => SinglePartition,
6567
expectShuffle = false)
6668

69+
/**
70+
* For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
71+
* `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
72+
* ensure the expected partitioning.
73+
*/
6774
private def testEnsureStatefulOpPartitioning(
6875
testName: String,
6976
inputPlan: SparkPlan,
7077
requiredDistribution: Seq[Attribute] => Distribution,
7178
expectedPartitioning: Seq[Attribute] => Partitioning,
7279
expectShuffle: Boolean): Unit = {
73-
test("EnsureStatefulOpPartitioning - " + testName) {
74-
val operator = TestOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
80+
test(testName) {
81+
val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
7582
val executed = executePlan(operator, OutputMode.Complete())
7683
if (expectShuffle) {
7784
val exchange = executed.children.find(_.isInstanceOf[Exchange])
@@ -88,6 +95,7 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
8895
}
8996
}
9097

98+
/** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
9199
private def executePlan(
92100
p: SparkPlan,
93101
outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
@@ -111,13 +119,14 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext
111119
}
112120
execution.executedPlan
113121
}
114-
}
115122

116-
case class TestOperator(
117-
child: SparkPlan,
118-
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
119-
override def output: Seq[Attribute] = child.output
120-
override def doExecute(): RDD[InternalRow] = child.execute()
121-
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
122-
override def stateInfo: Option[StatefulOperatorStateInfo] = None
123+
/** Used to emulate a [[StatefulOperator]] with the given requiredDistribution. */
124+
case class TestStatefulOperator(
125+
child: SparkPlan,
126+
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
127+
override def output: Seq[Attribute] = child.output
128+
override def doExecute(): RDD[InternalRow] = child.execute()
129+
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
130+
override def stateInfo: Option[StatefulOperatorStateInfo] = None
131+
}
123132
}

0 commit comments

Comments
 (0)