@@ -91,42 +91,42 @@ case class ClusteredDistribution(
 }
 
 /**
- * Represents the requirement of distribution on the stateful operator in Structured Streaming.
+ * Represents data where tuples have been clustered according to the hash of the given
+ * `expressions`. Since this distribution relies on [[HashPartitioning]] on the physical
+ * partitioning, only [[HashPartitioning]] (and HashPartitioning in [[PartitioningCollection]])
+ * can satisfy this distribution. When `requiredNumPartitions` is Some(1), [[SinglePartition]]
+ * is essentially the same as [[HashPartitioning]], so it can satisfy this distribution as well.
  *
- * Each partition in stateful operator initializes state store(s), which are independent with state
- * store(s) in other partitions. Since it is not possible to repartition the data in state store,
- * Spark should make sure the physical partitioning of the stateful operator is unchanged across
- * Spark versions. Violation of this requirement may bring silent correctness issue.
+ * This distribution is mainly used to represent the distribution requirement of stateful
+ * operators in Structured Streaming, but it can be used for other cases as well.
  *
- * Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the
- * stateful operator, only [[HashPartitioning]] (and HashPartitioning in
- * [[PartitioningCollection]]) can satisfy this distribution.
- * When `_requiredNumPartitions` is 1, [[SinglePartition]] is essentially same as
- * [[HashPartitioning]], so it can satisfy this distribution as well.
+ * NOTE 1: Each partition in a stateful operator initializes state store(s), which are
+ * independent of the state store(s) in other partitions. Since it is not possible to
+ * repartition the data in a state store, Spark must ensure that the physical partitioning of
+ * the stateful operator stays unchanged across Spark versions. Violating this requirement may
+ * cause silent correctness issues.
  *
- * NOTE: This is applied only to stream-stream join as of now. For other stateful operators, we
- * have been using ClusteredDistribution, which could construct the physical partitioning of the
- * state in different way (ClusteredDistribution requires relaxed condition and multiple
- * partitionings can satisfy the requirement.) We need to construct the way to fix this with
- * minimizing possibility to break the existing checkpoints.
+ * NOTE 2: As of now, this is applied only to stream-stream join among the stateful operators.
+ * The other stateful operators use ClusteredDistribution, which may construct the physical
+ * partitioning of the state in a different way (ClusteredDistribution imposes a relaxed
+ * condition, so multiple partitionings can satisfy the requirement). We need to find a way to
+ * fix this while minimizing the possibility of breaking existing checkpoints.
  *
- * TODO(SPARK-38204): address the issue explained in above note.
+ * TODO(SPARK-38204): address the issue explained in note 2.
  */
-case class StatefulOpClusteredDistribution(
+case class HashClusteredDistribution(
> Review comment (author): This and the lines below basically restore the implementation of HashClusteredDistribution.
     expressions: Seq[Expression],
-    _requiredNumPartitions: Int) extends Distribution {
+    requiredNumPartitions: Option[Int] = None) extends Distribution {
   require(
     expressions != Nil,
-    "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " +
+    "The expressions for hash of a HashClusteredDistribution should not be Nil. " +
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 
-  override val requiredNumPartitions: Option[Int] = Some(_requiredNumPartitions)
-
   override def createPartitioning(numPartitions: Int): Partitioning = {
-    assert(_requiredNumPartitions == numPartitions,
-      s"This StatefulOpClusteredDistribution requires ${_requiredNumPartitions} " +
-        s"partitions, but the actual number of partitions is $numPartitions.")
+    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
+      s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+        s"the actual number of partitions is $numPartitions.")
     HashPartitioning(expressions, numPartitions)
   }
 }
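For orientation, here is a minimal sketch of how the restored signature behaves (not part of the diff; it uses the catalyst test DSL, and the names are only illustrative). Leaving `requiredNumPartitions` as None lets the planner pick any partition count, while Some(n) pins it, exactly as the assert in `createPartitioning` enforces:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical._

// No pinned partition count: any numPartitions is acceptable.
val open = HashClusteredDistribution(Seq($"a", $"b"))
open.createPartitioning(5)     // HashPartitioning(Seq(a, b), 5)

// Pinned count, as stateful operators do with the checkpointed value.
val pinned = HashClusteredDistribution(Seq($"a", $"b"), Some(10))
pinned.createPartitioning(10)  // HashPartitioning(Seq(a, b), 10)
// pinned.createPartitioning(5) would fail the assert above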
@@ -242,7 +242,7 @@ case object SinglePartition extends Partitioning {
  * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  *
- * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
+ * Since [[HashClusteredDistribution]] relies on this partitioning and Spark requires
  * stateful operators to retain the same physical partitioning during the lifetime of the query
  * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
  * across Spark versions. Violation of this requirement may bring silent correctness issue.
@@ -257,7 +257,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
       required match {
-        case h: StatefulOpClusteredDistribution =>
+        case h: HashClusteredDistribution =>
           expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
             case (l, r) => l.semanticEquals(r)
           }
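The `partitionIdExpression` named in the classdoc above is the row-to-partition-id mapping that must stay stable across releases. For reference, a hedged paraphrase of its definition at the time (treat this as a sketch, not a quote of the source):

// Inside HashPartitioning: a Murmur3 hash of the partitioning expressions,
// taken modulo numPartitions. Changing the hash function or the modulo would
// silently remap checkpointed state keys to different partitions.
def partitionIdExpression: Expression =
  Pmod(new Murmur3Hash(expressions), Literal(numPartitions))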
@@ -407,6 +407,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val REQUIRE_ALL_CLUSTER_KEYS_FOR_AGGREGATE =
> Review comment (author): I picked a config name and description similar to the existing config above (spark.sql.requireAllClusterKeysForCoPartition), since the goal is very similar.
buildConf("spark.sql.aggregate.requireAllClusterKeys")
.internal()
.doc("When true, aggregate operator requires all the clustering keys as the hash partition" +
" keys from child. This is to avoid data skews which can lead to significant " +
"performance regression if shuffles are eliminated.")
.version("3.3.0")
.booleanConf
.createWithDefault(false)

val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort")
.internal()
.doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
Expand Down
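To make the flag's intent concrete, a hedged usage sketch (it mirrors the new test added further down; for streaming queries the flag is forced off, see the StreamExecution change below):

spark.conf.set("spark.sql.aggregate.requireAllClusterKeys", "true")

val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")

// key1 alone satisfies ClusteredDistribution(key1, key2), so without the flag
// the aggregate would reuse the skew-prone single-key partitioning. With the
// flag on, a shuffle on (key1, key2) is inserted instead.
df.repartition($"key1")
  .groupBy($"key1", $"key2")
  .agg(sum($"value"))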
@@ -282,65 +282,65 @@ class DistributionSuite extends SparkFunSuite {
   }
 
     // Validate only HashPartitioning (and HashPartitioning in PartitioningCollection) can satisfy
-    // StatefulOpClusteredDistribution. SinglePartition can also satisfy this distribution when
-    // `_requiredNumPartitions` is 1.
+    // HashClusteredDistribution. SinglePartition can also satisfy this distribution when
+    // `requiredNumPartitions` is Some(1).
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       true)
 
     checkSatisfied(
       PartitioningCollection(Seq(
         HashPartitioning(Seq($"a", $"b", $"c"), 10),
         RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10))),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       true)
 
     checkSatisfied(
       SinglePartition,
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(1)),
       true)
 
    checkSatisfied(
      PartitioningCollection(Seq(
        HashPartitioning(Seq($"a", $"b"), 1),
        SinglePartition)),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(1)),
       true)
 
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b"), 10),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 5),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       SinglePartition,
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       BroadcastPartitioning(IdentityBroadcastMode),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       RoundRobinPartitioning(10),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
     checkSatisfied(
       UnknownPartitioning(10),
-      StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10),
+      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
   }
 }
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, HashClusteredDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Holds common logic for aggregate operators
@@ -92,7 +93,12 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
       case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) => ClusteredDistribution(exprs) :: Nil
+      case Some(exprs) =>
+        if (conf.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_AGGREGATE)) {
+          HashClusteredDistribution(exprs) :: Nil
+        } else {
+          ClusteredDistribution(exprs) :: Nil
+        }
       case None => UnspecifiedDistribution :: Nil
     }
   }
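The planner-visible difference between the two branches, as a small sketch (catalyst test DSL assumed, names illustrative): hash partitioning on a subset of the grouping keys satisfies ClusteredDistribution but not HashClusteredDistribution, which is exactly what lets the flag force a full-key shuffle.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical._

val childPartitioning = HashPartitioning(Seq($"key1"), 10)

childPartitioning.satisfies(ClusteredDistribution(Seq($"key1", $"key2")))      // true: a subset suffices
childPartitioning.satisfies(HashClusteredDistribution(Seq($"key1", $"key2")))  // false: all keys, exactly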
@@ -93,7 +93,7 @@ case class FlatMapGroupsWithStateExec(
    * to have the same grouping so that the data are co-lacated on the same task.
    */
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
@@ -287,6 +287,9 @@ abstract class StreamExecution(
       // Disable cost-based join optimization as we do not want stateful operations
       // to be rearranged
       sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
+      // Disable any config affecting the required child distribution of stateful operators.
+      // Please read through the NOTE on the classdoc of HashClusteredDistribution for details.
+      sparkSessionForStream.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_AGGREGATE.key, "false")
> Review comment (author): This is super important. The new config should never be set to true until we fix the fundamental problem in a backward-compatible way, since stateful operators would follow the changed output partitioning as well.
 
       updateStatusMessage("Initializing sources")
       // force initialization of the logical plan so that the sources can be created
@@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec(
   val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
 
   override def requiredChildDistribution: Seq[Distribution] =
-    StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) ::
-      StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil
+    HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+      HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
 
   override def output: Seq[Attribute] = joinType match {
     case _: InnerLike => left.output ++ right.output
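Note the switch from a mandatory count in getStateInfo.numPartitions to stateInfo.map(_.numPartitions): on restart, stateInfo carries the partition count recorded in the checkpoint and pins the shuffle to it; on a fresh query it is None and the default applies. A hedged sketch (checkpointedPartitions is hypothetical):

// Join keys must hash-partition into exactly the checkpointed count,
// otherwise restored state rows would land in the wrong partitions.
HashClusteredDistribution(leftKeys, Some(checkpointedPartitions))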
@@ -334,7 +334,7 @@ case class StateStoreRestoreExec(
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     if (keyExpressions.isEmpty) {

@@ -496,7 +496,7 @@ case class StateStoreSaveExec(
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     if (keyExpressions.isEmpty) {

@@ -579,7 +579,7 @@ case class SessionWindowStateStoreRestoreExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil

@@ -693,7 +693,7 @@ case class SessionWindowStateStoreSaveExec(
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil

@@ -754,7 +754,7 @@ case class StreamingDeduplicateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] = {
-    // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
+    // NOTE: Please read through the NOTE on the classdoc of HashClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
     ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
@@ -24,10 +24,12 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.{InputAdapter, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -1453,6 +1455,57 @@ class DataFrameAggregateSuite extends QueryTest
     val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
     checkAnswer(df, Row(2, 3, 1))
   }
+
+  test("SPARK-38237: require all cluster keys for child required distribution") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
+      expressions.flatMap {
+        case ref: AttributeReference => Some(ref.name)
+      }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ShuffleExchangeExec,
+        desiredClusterColumns: Seq[String],
+        desiredNumPartitions: Int): Boolean = plan match {
+      case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS)
+          if partitionExpressionsColumns(op.expressions) === desiredClusterColumns &&
+            op.numPartitions === desiredNumPartitions => true
+
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_AGGREGATE.key -> "true") {
+
+      val grouped = df
+        // repartition by a subset of the group keys, which satisfies
+        // ClusteredDistribution(group keys) but not HashClusteredDistribution(group keys)
+        .repartition($"key1")
+        .groupBy($"key1", $"key2")
+        .agg(sum($"value"))
+
+      checkAnswer(grouped, Seq(Row("a", 1, 1), Row("a", 2, 2), Row("b", 1, 7)))
+
+      val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
+
+      val shuffleByRequirement = grouped.queryExecution.executedPlan.flatMap {
+        case a if a.isInstanceOf[BaseAggregateExec] =>
+          a.children.head match {
+            case InputAdapter(s: ShuffleExchangeExec)
+                if isShuffleExecByRequirement(s, Seq("key1", "key2"), numPartitions) => Some(s)
+            case s: ShuffleExchangeExec
+                if isShuffleExecByRequirement(s, Seq("key1", "key2"), numPartitions) => Some(s)
+            case _ => None
+          }
+
+        case _ => None
+      }
+
+      assert(shuffleByRequirement.nonEmpty, "Can't find desired shuffle node from the query plan")
+    }
+  }
 }
 
 case class B(c: Option[Double])
@@ -571,7 +571,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
       CheckNewAnswer((5, 10, 5, 15, 5, 25)))
   }
 
-  test("streaming join should require StatefulOpClusteredDistribution from children") {
+  test("streaming join should require HashClusteredDistribution from children") {
     val input1 = MemoryStream[Int]
     val input2 = MemoryStream[Int]
 