diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 7a730c4b7318b..4418d3253a8b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -90,6 +90,37 @@ case class ClusteredDistribution(
   }
 }
 
+/**
+ * Represents the requirement of distribution on the stateful operator in Structured Streaming.
+ *
+ * Each partition in a stateful operator initializes state store(s), which are independent of the
+ * state store(s) in other partitions. Since it is not possible to repartition the data in a state
+ * store, Spark should make sure the physical partitioning of the stateful operator is unchanged
+ * across Spark versions. Violation of this requirement may lead to silent correctness issues.
+ *
+ * Since this distribution relies on [[HashPartitioning]] for the physical partitioning of the
+ * stateful operator, only [[HashPartitioning]] (and HashPartitioning in
+ * [[PartitioningCollection]]) can satisfy this distribution.
+ */
+case class StatefulOpClusteredDistribution(
+    expressions: Seq[Expression],
+    _requiredNumPartitions: Int) extends Distribution {
+  require(
+    expressions != Nil,
+    "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+
+  override val requiredNumPartitions: Option[Int] = Some(_requiredNumPartitions)
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(_requiredNumPartitions == numPartitions,
+      s"This StatefulOpClusteredDistribution requires ${_requiredNumPartitions} " +
+        s"partitions, but the actual number of partitions is $numPartitions.")
+    HashPartitioning(expressions, numPartitions)
+  }
+}
+
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. Its requirement is defined as the following:
@@ -200,6 +231,11 @@ case object SinglePartition extends Partitioning {
  * Represents a partitioning where rows are split up across partitions based on the hash
  * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
+ *
+ * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
+ * stateful operators to retain the same physical partitioning during the lifetime of the query
+ * (including restarts), the result of evaluating `partitionIdExpression` must be unchanged
+ * across Spark versions. Violation of this requirement may lead to silent correctness issues.
  */
 case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   extends Expression with Partitioning with Unevaluable {
@@ -211,6 +247,10 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
       required match {
+        case h: StatefulOpClusteredDistribution =>
+          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
+            case (l, r) => l.semanticEquals(r)
+          }
         case ClusteredDistribution(requiredClustering, _) =>
           expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
         case _ => false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 74b82451e029f..adb84a3b7d3fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec(
   val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
-      ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
+    StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) ::
+      StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil
 
   override def output: Seq[Attribute] = joinType match {
     case _: InnerLike => left.output ++ right.output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 5ec47bb2aa527..e0926ef0a82ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -571,7 +571,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
       CheckNewAnswer((5, 10, 5, 15, 5, 25)))
   }
 
-  test("streaming join should require HashClusteredDistribution from children") {
+  test("streaming join should require StatefulOpClusteredDistribution from children") {
     val input1 = MemoryStream[Int]
     val input2 = MemoryStream[Int]
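
To make the behavioral difference in `satisfies0` concrete, here is a minimal, self-contained Scala sketch. It is not Spark source: `Expr` and its `semanticEquals` are simplified stand-ins for Spark's Expression machinery, and the two functions mirror the subset check of ClusteredDistribution versus the exact pairwise check the patch introduces for StatefulOpClusteredDistribution.

// Standalone sketch: contrasts the two satisfaction rules in the patch above.
object SatisfiesSketch {
  // Simplified stand-in for Spark's Expression with semantic equality.
  final case class Expr(name: String) {
    def semanticEquals(other: Expr): Boolean = name == other.name
  }

  // ClusteredDistribution-style rule: every partitioning expression only has
  // to appear somewhere in the required clustering keys (a subset test).
  def satisfiesClustered(partExprs: Seq[Expr], required: Seq[Expr]): Boolean =
    partExprs.forall(x => required.exists(_.semanticEquals(x)))

  // StatefulOpClusteredDistribution-style rule: the expression lists must
  // match pairwise, in order and in full, so the hash input (and hence each
  // row's partition id) is exactly the one the state store layout was
  // written with.
  def satisfiesStatefulOp(partExprs: Seq[Expr], required: Seq[Expr]): Boolean =
    partExprs.length == required.length &&
      partExprs.zip(required).forall { case (l, r) => l.semanticEquals(r) }

  def main(args: Array[String]): Unit = {
    val partitioning = Seq(Expr("a"))            // data hashed by a subset of the keys
    val requiredKeys = Seq(Expr("a"), Expr("b"))
    println(satisfiesClustered(partitioning, requiredKeys))  // true: fine for a plain join
    println(satisfiesStatefulOp(partitioning, requiredKeys)) // false: would break state reuse
  }
}

The relaxed rule lets a stateless join reuse an existing partitioning on a subset of its keys, avoiding a shuffle. For a stateful operator, that same flexibility would change the hash input between runs and route rows to partitions whose state stores never held their earlier state, which is why the strict rule is required.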