From 90595dd91a868ec146347659022ecb22eab5e430 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 7 Feb 2022 16:05:23 +0900 Subject: [PATCH 1/6] [SPARK-38124][SQL][SS] Introduce StatefulOpClusteredDistribution and apply to all stateful operators --- .../plans/physical/partitioning.scala | 75 +++++++++++++++++++ .../sql/execution/aggregate/AggUtils.scala | 20 +++-- .../aggregate/BaseAggregateExec.scala | 18 +++-- .../aggregate/HashAggregateExec.scala | 2 + .../aggregate/ObjectHashAggregateExec.scala | 2 + .../aggregate/SortAggregateExec.scala | 2 + .../exchange/ShuffleExchangeExec.scala | 10 +++ .../FlatMapGroupsWithStateExec.scala | 6 +- .../StreamingSymmetricHashJoinExec.scala | 4 +- .../streaming/statefulOperators.scala | 13 ++-- .../sql/streaming/StreamingJoinSuite.scala | 6 +- 11 files changed, 131 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 7a730c4b7318b..75b42807dd143 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -90,6 +90,34 @@ case class ClusteredDistribution( } } +/** + * Represents the requirement of distribution on the stateful operator. + * + * Each partition in stateful operator initializes state store(s), which are independent with state + * store(s) in other partitions. Since it is not possible to repartition the data in state store, + * Spark should make sure the physical partitioning of the stateful operator is unchanged across + * Spark versions. Violation of this requirement may bring silent correctness issue. + * + * Since this distribution relies on [[StatefulOpPartitioning]] on the physical partitioning of the + * stateful operator, only [[StatefulOpPartitioning]] can satisfy this distribution. + */ +case class StatefulOpClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This StatefulOpClusteredDistribution requires ${requiredNumPartitions.get} " + + s"partitions, but the actual number of partitions is $numPartitions.") + StatefulOpPartitioning(expressions, numPartitions) + } +} + /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. Its requirement is defined as the following: @@ -231,6 +259,53 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } +/** + * Represents the partitioning of stateful operator. + * + * This is basically hash partitioning, where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. 
+ * + * Since we require stateful operator to retain the same physical partitioning during the lifetime + * of the query (including restart), the implementation of `partitionIdExpression` must be unchanged + * across Spark versions. + */ +case class StatefulOpPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends Expression with Partitioning with Unevaluable { + + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { + required match { + case h: StatefulOpClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false + } + } + } + + /** + * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less + * than numPartitions) based on hashing expressions. + * + * NOTE: Spark must ensure this expression with specific tuple evaluates to the same value + * across Spark versions. + */ + def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): StatefulOpPartitioning = { + copy(expressions = newChildren) + } +} + /** * Represents a partitioning where rows are split across partitions based on some total ordering of * the expressions specified in `ordering`. When data is partitioned in this manner, it guarantees: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 32db622c9f931..972e2aa85f24d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -46,6 +47,7 @@ object AggUtils { } private def createAggregate( + requiredChildDistributionOption: Option[Seq[Distribution]] = None, requiredChildDistributionExpressions: Option[Seq[Expression]] = None, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, @@ -59,6 +61,7 @@ object AggUtils { if (useHash && !forceSortAggregate) { HashAggregateExec( + requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), @@ -72,6 +75,7 @@ object AggUtils { if (objectHashEnabled && useObjectHash && !forceSortAggregate) { ObjectHashAggregateExec( + requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = 
mayRemoveAggFilters(aggregateExpressions), @@ -81,6 +85,7 @@ object AggUtils { child = child) } else { SortAggregateExec( + requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), @@ -299,12 +304,15 @@ object AggUtils { child = child) } + // This is only used to pick up the required child distribution for the stateful operator + val tempRestored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialAggregate) + val partialMerged1: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionExpressions = - Some(groupingAttributes), + requiredChildDistributionOption = Some(tempRestored.requiredChildDistribution), groupingExpressions = groupingAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -314,15 +322,13 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, - partialMerged1) + val restored = tempRestored.copy(child = partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionExpressions = - Some(groupingAttributes), + requiredChildDistributionOption = Some(restored.requiredChildDistribution), groupingExpressions = groupingAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -349,7 +355,7 @@ object AggUtils { val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), + requiredChildDistributionOption = Some(restored.requiredChildDistribution), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, aggregateAttributes = finalAggregateAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index b709c8092e46d..9f79dde2d66aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtil */ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning { def requiredChildDistributionExpressions: Option[Seq[Expression]] + def requiredChildDistributionOption: Option[Seq[Distribution]] def groupingExpressions: Seq[NamedExpression] def aggregateExpressions: Seq[AggregateExpression] def aggregateAttributes: Seq[Attribute] @@ -90,10 +91,14 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override protected def outputExpressions: Seq[NamedExpression] = resultExpressions override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) => ClusteredDistribution(exprs) :: Nil - case None 
=> UnspecifiedDistribution :: Nil + requiredChildDistributionOption match { + case Some(dist) => dist.toList + case _ => + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } } } @@ -102,7 +107,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning */ def toSortAggregate: SortAggregateExec = { SortAggregateExec( - requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, - aggregateAttributes, initialInputBufferOffset, resultExpressions, child) + requiredChildDistributionOption, requiredChildDistributionExpressions, groupingExpressions, + aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions, + child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ef0eb3e5da257..5856a1096bebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ @@ -44,6 +45,7 @@ import org.apache.spark.util.Utils * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. 
*/ case class HashAggregateExec( + requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index c98c9f42e69da..4c59b6003c171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -58,6 +59,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * }}} */ case class ObjectHashAggregateExec( + requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index a0557822795af..c6878cb674b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.SQLConf * Sort-based aggregate operator. */ case class SortAggregateExec( + requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index c033aedc7786d..984d64d7f4fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -274,6 +274,13 @@ object ShuffleExchangeExec { // `HashPartitioning.partitionIdExpression` to produce partitioning key. 
override def getPartition(key: Any): Int = key.asInstanceOf[Int] } + case StatefulOpPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For StatefulOpPartitioning, the partitioning key is already a valid partition ID, as + // we use `StatefulOpPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -315,6 +322,9 @@ object ShuffleExchangeExec { case h: HashPartitioning => val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) + case h: StatefulOpPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) case RangePartitioning(sortingExpressions, _) => val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index a00a62216f3dc..84578254ad47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, StatefulOpClusteredDistribution} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ @@ -93,8 +93,8 @@ case class FlatMapGroupsWithStateExec( * to have the same grouping so that the data are co-lacated on the same task. 
*/ override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: + StatefulOpClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: + StatefulOpClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 74b82451e029f..f25f24accceef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + StatefulOpClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3431823765c1b..6a9770889d616 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, StatefulOpClusteredDistribution} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ @@ -337,7 +337,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } @@ -496,7 +496,7 @@ case class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } @@ -573,7 +573,8 @@ case class SessionWindowStateStoreRestoreExec( } override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(keyWithoutSessionExpressions, + stateInfo.map(_.numPartitions)) :: Nil } override def requiredChildOrdering: 
Seq[Seq[SortOrder]] = { @@ -684,7 +685,7 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { @@ -742,7 +743,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 5ec47bb2aa527..388a48c013b57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.plans.physical.StatefulOpPartitioning import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} @@ -595,8 +595,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { assert(query.lastExecution.executedPlan.collect { case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, - ShuffleExchangeExec(opA: HashPartitioning, _, _), - ShuffleExchangeExec(opB: HashPartitioning, _, _)) + ShuffleExchangeExec(opA: StatefulOpPartitioning, _, _), + ShuffleExchangeExec(opB: StatefulOpPartitioning, _, _)) if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j From 4639b7ecd5c24e482ad488208ce38a52702a66ef Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 7 Feb 2022 16:41:09 +0900 Subject: [PATCH 2/6] Remove StatefulOpPartitioning and merge to HashPartitioning --- .../plans/physical/partitioning.scala | 62 ++++--------------- .../sql/execution/aggregate/AggUtils.scala | 6 +- .../exchange/ShuffleExchangeExec.scala | 10 --- .../sql/streaming/StreamingJoinSuite.scala | 8 +-- 4 files changed, 20 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 75b42807dd143..febb05c8c2f02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -98,8 +98,8 @@ 
case class ClusteredDistribution( * Spark should make sure the physical partitioning of the stateful operator is unchanged across * Spark versions. Violation of this requirement may bring silent correctness issue. * - * Since this distribution relies on [[StatefulOpPartitioning]] on the physical partitioning of the - * stateful operator, only [[StatefulOpPartitioning]] can satisfy this distribution. + * Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the + * stateful operator, only [[HashPartitioning]] can satisfy this distribution. */ case class StatefulOpClusteredDistribution( expressions: Seq[Expression], @@ -114,7 +114,7 @@ case class StatefulOpClusteredDistribution( assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, s"This StatefulOpClusteredDistribution requires ${requiredNumPartitions.get} " + s"partitions, but the actual number of partitions is $numPartitions.") - StatefulOpPartitioning(expressions, numPartitions) + HashPartitioning(expressions, numPartitions) } } @@ -228,49 +228,13 @@ case object SinglePartition extends Partitioning { * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. - */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: DataType = IntegerType - - override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case ClusteredDistribution(requiredClustering, _) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) - case _ => false - } - } - } - - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - HashShuffleSpec(this, distribution) - - /** - * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less - * than numPartitions) based on hashing expressions. - */ - def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) - - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) -} - -/** - * Represents the partitioning of stateful operator. - * - * This is basically hash partitioning, where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. * - * Since we require stateful operator to retain the same physical partitioning during the lifetime - * of the query (including restart), the implementation of `partitionIdExpression` must be unchanged - * across Spark versions. + * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires + * stateful operators to retain the same physical partitioning during the lifetime of the query + * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged + * across Spark versions. Violation of this requirement may bring silent correctness issue. 
*/ -case class StatefulOpPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions @@ -291,19 +255,17 @@ case class StatefulOpPartitioning(expressions: Seq[Expression], numPartitions: I } } + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + HashShuffleSpec(this, distribution) + /** * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less * than numPartitions) based on hashing expressions. - * - * NOTE: Spark must ensure this expression with specific tuple evaluates to the same value - * across Spark versions. */ def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): StatefulOpPartitioning = { - copy(expressions = newChildren) - } + newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 972e2aa85f24d..66aeb2bb655e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -304,7 +304,8 @@ object AggUtils { child = child) } - // This is only used to pick up the required child distribution for the stateful operator + // This is used temporarily to pick up the required child distribution for the stateful + // operator. val tempRestored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, partialAggregate) @@ -322,7 +323,8 @@ object AggUtils { child = partialAggregate) } - val restored = tempRestored.copy(child = partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 984d64d7f4fb2..c033aedc7786d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -274,13 +274,6 @@ object ShuffleExchangeExec { // `HashPartitioning.partitionIdExpression` to produce partitioning key. override def getPartition(key: Any): Int = key.asInstanceOf[Int] } - case StatefulOpPartitioning(_, n) => - new Partitioner { - override def numPartitions: Int = n - // For StatefulOpPartitioning, the partitioning key is already a valid partition ID, as - // we use `StatefulOpPartitioning.partitionIdExpression` to produce partitioning key. 
- override def getPartition(key: Any): Int = key.asInstanceOf[Int] - } case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -322,9 +315,6 @@ object ShuffleExchangeExec { case h: HashPartitioning => val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) - case h: StatefulOpPartitioning => - val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) - row => projection(row).getInt(0) case RangePartitioning(sortingExpressions, _) => val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 388a48c013b57..e0926ef0a82ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.physical.StatefulOpPartitioning +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} @@ -571,7 +571,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { CheckNewAnswer((5, 10, 5, 15, 5, 25))) } - test("streaming join should require HashClusteredDistribution from children") { + test("streaming join should require StatefulOpClusteredDistribution from children") { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] @@ -595,8 +595,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { assert(query.lastExecution.executedPlan.collect { case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, - ShuffleExchangeExec(opA: StatefulOpPartitioning, _, _), - ShuffleExchangeExec(opB: StatefulOpPartitioning, _, _)) + ShuffleExchangeExec(opA: HashPartitioning, _, _), + ShuffleExchangeExec(opB: HashPartitioning, _, _)) if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j From adfe79628bd0a3a4cd5336861f270d290170936a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 7 Feb 2022 21:10:09 +0900 Subject: [PATCH 3/6] fix compilation --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 7332d49b942f8..30fe9d2f904cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -742,7 +742,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.find { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _, _: LocalTableScanExec)) => true case _ => false }.isDefined, "LocalTableScanExec should be within a WholeStageCodegen domain.") From d722b2cd6b741e3784601b98d69ff953de600103 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 8 Feb 2022 16:20:37 +0900 Subject: [PATCH 4/6] Revert the changes in non stream-stream join operators --- .../sql/execution/aggregate/AggUtils.scala | 18 +++++------------- .../aggregate/BaseAggregateExec.scala | 18 ++++++------------ .../aggregate/HashAggregateExec.scala | 2 -- .../aggregate/ObjectHashAggregateExec.scala | 2 -- .../aggregate/SortAggregateExec.scala | 2 -- .../streaming/FlatMapGroupsWithStateExec.scala | 6 +++--- .../streaming/statefulOperators.scala | 13 ++++++------- .../sql/execution/WholeStageCodegenSuite.scala | 2 +- 8 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 66aeb2bb655e4..32db622c9f931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -47,7 +46,6 @@ object AggUtils { } private def createAggregate( - requiredChildDistributionOption: Option[Seq[Distribution]] = None, requiredChildDistributionExpressions: Option[Seq[Expression]] = None, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, @@ -61,7 +59,6 @@ object AggUtils { if (useHash && !forceSortAggregate) { HashAggregateExec( - requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), @@ -75,7 +72,6 @@ object AggUtils { if (objectHashEnabled && useObjectHash && !forceSortAggregate) { ObjectHashAggregateExec( - requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), @@ -85,7 +81,6 @@ object AggUtils { child = child) } else { SortAggregateExec( - requiredChildDistributionOption = requiredChildDistributionOption, requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), @@ -304,16 +299,12 @@ object AggUtils { child = child) } - // This is used temporarily to pick up the required child distribution for the stateful - // operator. 
- val tempRestored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, - partialAggregate) - val partialMerged1: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionOption = Some(tempRestored.requiredChildDistribution), + requiredChildDistributionExpressions = + Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -330,7 +321,8 @@ object AggUtils { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionOption = Some(restored.requiredChildDistribution), + requiredChildDistributionExpressions = + Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -357,7 +349,7 @@ object AggUtils { val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) createAggregate( - requiredChildDistributionOption = Some(restored.requiredChildDistribution), + requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, aggregateAttributes = finalAggregateAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 9f79dde2d66aa..b709c8092e46d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtil */ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning { def requiredChildDistributionExpressions: Option[Seq[Expression]] - def requiredChildDistributionOption: Option[Seq[Distribution]] def groupingExpressions: Seq[NamedExpression] def aggregateExpressions: Seq[AggregateExpression] def aggregateAttributes: Seq[Attribute] @@ -91,14 +90,10 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override protected def outputExpressions: Seq[NamedExpression] = resultExpressions override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionOption match { - case Some(dist) => dist.toList - case _ => - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil } } @@ -107,8 +102,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning */ def toSortAggregate: SortAggregateExec = { SortAggregateExec( - requiredChildDistributionOption, requiredChildDistributionExpressions, groupingExpressions, - aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions, - child) + requiredChildDistributionExpressions, 
groupingExpressions, aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 5856a1096bebb..ef0eb3e5da257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ @@ -45,7 +44,6 @@ import org.apache.spark.util.Utils * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. */ case class HashAggregateExec( - requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 4c59b6003c171..c98c9f42e69da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -23,7 +23,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -59,7 +58,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * }}} */ case class ObjectHashAggregateExec( - requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index c6878cb674b1e..a0557822795af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.util.truncatedString import 
org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -32,7 +31,6 @@ import org.apache.spark.sql.internal.SQLConf * Sort-based aggregate operator. */ case class SortAggregateExec( - requiredChildDistributionOption: Option[Seq[Distribution]], requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 84578254ad47b..a00a62216f3dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, StatefulOpClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ @@ -93,8 +93,8 @@ case class FlatMapGroupsWithStateExec( * to have the same grouping so that the data are co-lacated on the same task. */ override def requiredChildDistribution: Seq[Distribution] = { - StatefulOpClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: - StatefulOpClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6a9770889d616..3431823765c1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, StatefulOpClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ @@ -337,7 +337,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } @@ -496,7 +496,7 @@ case 
class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } @@ -573,8 +573,7 @@ case class SessionWindowStateStoreRestoreExec( } override def requiredChildDistribution: Seq[Distribution] = { - StatefulOpClusteredDistribution(keyWithoutSessionExpressions, - stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { @@ -685,7 +684,7 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { @@ -743,7 +742,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 30fe9d2f904cf..7332d49b942f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -742,7 +742,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.find { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true case _ => false }.isDefined, "LocalTableScanExec should be within a WholeStageCodegen domain.") From 4b68d289d32313e75fa9c7c08f63f794a323bc5c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 8 Feb 2022 17:21:53 +0900 Subject: [PATCH 5/6] update comment --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index febb05c8c2f02..591c58b62f227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,7 +99,8 @@ case class ClusteredDistribution( * Spark versions. Violation of this requirement may bring silent correctness issue. * * Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the - * stateful operator, only [[HashPartitioning]] can satisfy this distribution. + * stateful operator, only [[HashPartitioning]] (and HashPartitioning in + * [[PartitioningCollection]]) can satisfy this distribution. 
*/ case class StatefulOpClusteredDistribution( expressions: Seq[Expression], From 753a5b6c62c88f1a9e0318ebcd655cb04beee058 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 9 Feb 2022 06:14:15 +0900 Subject: [PATCH 6/6] reflect review comments --- .../sql/catalyst/plans/physical/partitioning.scala | 10 ++++++---- .../streaming/StreamingSymmetricHashJoinExec.scala | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 591c58b62f227..4418d3253a8b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -91,7 +91,7 @@ case class ClusteredDistribution( } /** - * Represents the requirement of distribution on the stateful operator. + * Represents the requirement of distribution on the stateful operator in Structured Streaming. * * Each partition in stateful operator initializes state store(s), which are independent with state * store(s) in other partitions. Since it is not possible to repartition the data in state store, @@ -104,16 +104,18 @@ case class ClusteredDistribution( */ case class StatefulOpClusteredDistribution( expressions: Seq[Expression], - requiredNumPartitions: Option[Int] = None) extends Distribution { + _requiredNumPartitions: Int) extends Distribution { require( expressions != Nil, "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + override val requiredNumPartitions: Option[Int] = Some(_requiredNumPartitions) + override def createPartitioning(numPartitions: Int): Partitioning = { - assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, - s"This StatefulOpClusteredDistribution requires ${requiredNumPartitions.get} " + + assert(_requiredNumPartitions == numPartitions, + s"This StatefulOpClusteredDistribution requires ${_requiredNumPartitions} " + s"partitions, but the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index f25f24accceef..adb84a3b7d3fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - StatefulOpClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - StatefulOpClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) :: + StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output
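
For reference only (not part of the patches above): a minimal, self-contained Scala sketch of the distinction this series introduces — the loose subset check that `ClusteredDistribution` accepts versus the strict, order-sensitive, partition-count-pinning requirement that `StatefulOpClusteredDistribution` ends up enforcing after PATCH 6. The types below are simplified stand-ins, not the real Catalyst classes, and the strict check folds the mandatory `requiredNumPartitions` into one method for illustration; the real `satisfies0` and `EnsureRequirements` split that work.

// Illustrative sketch only — simplified stand-ins for the Catalyst classes touched by this
// patch series; names and signatures here are not the real org.apache.spark.sql.catalyst API.
object StatefulDistributionSketch {
  // A clustering expression is reduced to a plain String key for this sketch.
  final case class HashPartitioning(expressions: Seq[String], numPartitions: Int) {
    // Loose check in the spirit of ClusteredDistribution: every partitioning expression must
    // appear somewhere in the required clustering; extra keys and reordering are tolerated.
    def satisfiesClustered(requiredClustering: Seq[String]): Boolean =
      expressions.forall(requiredClustering.contains)

    // Strict check in the spirit of StatefulOpClusteredDistribution: expressions must match
    // one-to-one and in order, and the partition count must equal the required one.
    def satisfiesStatefulOp(required: Seq[String], requiredNumPartitions: Int): Boolean =
      numPartitions == requiredNumPartitions &&
        expressions.length == required.length &&
        expressions.zip(required).forall { case (l, r) => l == r }
  }

  def main(args: Array[String]): Unit = {
    val childPartitioning = HashPartitioning(Seq("a"), numPartitions = 10)

    // Against a required clustering on (a, b) with 10 partitions:
    println(childPartitioning.satisfiesClustered(Seq("a", "b")))                        // true  (loose)
    println(childPartitioning.satisfiesStatefulOp(Seq("a", "b"), 10))                   // false (strict)
    println(HashPartitioning(Seq("a", "b"), 10).satisfiesStatefulOp(Seq("a", "b"), 10)) // true
  }
}

The stricter contract matters because each stateful operator partition owns its own state store files, keyed by partition id, and state cannot be repartitioned after the fact; if the planner were free to satisfy the requirement with a looser or differently ordered partitioning between runs, a restarted query could silently read the wrong state. The same reasoning is why the patch keeps `partitionIdExpression` as `Pmod(new Murmur3Hash(expressions), Literal(numPartitions))` and documents that its evaluation must stay stable across Spark versions.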