@@ -21,11 +21,13 @@ import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.streaming.OutputMode

/**
@@ -89,7 +91,7 @@ class IncrementalExecution(
override def apply(plan: SparkPlan): SparkPlan = plan transform {
case StateStoreSaveExec(keys, None, None, None,
UnaryExecNode(agg,
StateStoreRestoreExec(keys2, None, child))) =>
StateStoreRestoreExec(_, None, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
@@ -117,8 +119,34 @@
}
}

override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
override def preparations: Seq[Rule[SparkPlan]] =
Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations

/** No need to assert supported, as this check has already been done */
override def assertSupported(): Unit = { }
}

object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
// Needs to be transformUp to avoid extra shuffles
override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case so: StatefulOperator =>
val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
val distributions = so.requiredChildDistribution
val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
val expectedPartitioning = reqDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
case _ => throw new AnalysisException("Unexpected distribution required by " +
s"stateful operator: $so. Expected AllTuples or ClusteredDistribution but got " +
s"$reqDistribution.")
}
if (child.outputPartitioning.guarantees(expectedPartitioning) &&
child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
child
} else {
ShuffleExchange(expectedPartitioning, child)
}
}
so.withNewChildren(children)
}
}
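
The heart of the rule is the mapping from required distribution to expected partitioning, followed by the `guarantees` check. Below is a minimal standalone sketch of those two steps, not taken from the patch: it assumes Spark's catalyst module is on the classpath, uses 200 as a stand-in for `spark.sql.shuffle.partitions`, and `userId` is a made-up grouping key.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.LongType

val numShufflePartitions = 200                      // stand-in for spark.sql.shuffle.partitions
val key = AttributeReference("userId", LongType)()  // hypothetical grouping key

// Required distribution -> expected partitioning, mirroring the rule above.
val required: Distribution = ClusteredDistribution(key :: Nil)
val expected: Partitioning = required match {
  case AllTuples                   => SinglePartition
  case ClusteredDistribution(keys) => HashPartitioning(keys, numShufflePartitions)
  case other                       => sys.error(s"unexpected distribution: $other")
}

// `guarantees` only holds when both the keys and the partition count line up, so a child
// hash-partitioned on the right key but into a different number of partitions still gets
// wrapped in a ShuffleExchange.
assert(HashPartitioning(key :: Nil, numShufflePartitions).guarantees(expected))
assert(!HashPartitioning(key :: Nil, 10).guarantees(expected))
```

This pair of checks is what decides between reusing `child` as-is and inserting `ShuffleExchange(expectedPartitioning, child)`.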
@@ -829,6 +829,7 @@ class StreamExecution(
if (streamDeathCause != null) {
throw streamDeathCause
}
if (!isActive) return
Review comment (Contributor): +1 good catch

awaitBatchLock.lock()
try {
noNewData = false
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -200,18 +200,35 @@ case class StateStoreRestoreExec(
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
numOutputRows += 1
row +: Option(savedState).toSeq
val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
Review comment (Contributor): add docs on why we are doing this. similar to the docs in other places related to batch aggregation.
Reply (Contributor Author): there wasn't any docs in batch :)

// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
// the `HashAggregateExec` will output a 0 value for the partial merge. We need to
// restore the value, so that we don't overwrite our state with a 0 value, but rather
// merge the 0 with existing state.
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
numOutputRows += 1
row +: Option(savedState).toSeq
}
}
}
}

override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
}
}
}
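
As a hedged illustration of the empty-`keyExpressions` path handled above (not code from the patch; the source, sink, and checkpoint path are arbitrary choices for the sketch): a global streaming aggregation has no grouping keys, so its required child distribution is `AllTuples`, the restore/save pair runs on a single partition, and a trigger with no new input must re-emit the saved value rather than let a fresh zero overwrite it.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.count

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("global-agg-sketch")
  .getOrCreate()

// No groupBy -> keyExpressions is empty -> AllTuples -> SinglePartition for the state store.
val totals = spark.readStream
  .format("rate")          // any streaming source would do for this sketch
  .load()
  .agg(count("value"))

val query = totals.writeStream
  .outputMode("complete")                                 // Complete mode keeps the running total
  .format("console")
  .option("checkpointLocation", "/tmp/global-agg-chk")    // hypothetical checkpoint path
  .start()
```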

/**
@@ -351,6 +368,14 @@ case class StateStoreSaveExec(
override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
}
}
}

/** Physical operator for executing streaming Deduplicate. */
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
import org.apache.spark.sql.test.SharedSQLContext

class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {

import testImplicits._
// Initialize the shared SparkSession up front (SharedSQLContext does this in beforeAll)
// so that `baseDf` below can be created while the suite is being constructed.
super.beforeAll()

private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")

testEnsureStatefulOpPartitioning(
"ClusteredDistribution generates Exchange with HashPartitioning",
baseDf.queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning",
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"AllTuples generates Exchange with SinglePartition",
baseDf.queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"AllTuples with coalesce(1) doesn't need Exchange",
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = false)

/**
* For a `StatefulOperator` with the given `requiredChildDistribution` and child SparkPlan
* `inputPlan`, verifies that the incremental planner adds exchanges, if required, so that the
* child satisfies the expected partitioning.
*/
private def testEnsureStatefulOpPartitioning(
testName: String,
inputPlan: SparkPlan,
requiredDistribution: Seq[Attribute] => Distribution,
expectedPartitioning: Seq[Attribute] => Partitioning,
expectShuffle: Boolean): Unit = {
test(testName) {
val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
val executed = executePlan(operator, OutputMode.Complete())
if (expectShuffle) {
val exchange = executed.children.find(_.isInstanceOf[Exchange])
if (exchange.isEmpty) {
fail(s"Was expecting an exchange but didn't get one in:\n$executed")
}
assert(exchange.get ===
ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
s"Exchange didn't have expected properties:\n${exchange.get}")
} else {
assert(!executed.children.exists(_.isInstanceOf[Exchange]),
s"Unexpected exchange found in:\n$executed")
}
}
}

/** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
private def executePlan(
p: SparkPlan,
outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
val execution = new IncrementalExecution(
spark,
null,
OutputMode.Complete(),
"chk",
UUID.randomUUID(),
0L,
OffsetSeqMetadata()) {
override lazy val sparkPlan: SparkPlan = p transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}
execution.executedPlan
}
}

/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
case class TestStatefulOperator(
child: SparkPlan,
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
override def output: Seq[Attribute] = child.output
override def doExecute(): RDD[InternalRow] = child.execute()
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
override def stateInfo: Option[StatefulOperatorStateInfo] = None
}
@@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class StartStream(
trigger: Trigger = Trigger.ProcessingTime(0),
triggerClock: Clock = new SystemClock,
additionalConfs: Map[String, String] = Map.empty)
additionalConfs: Map[String, String] = Map.empty,
checkpointLocation: String = null)
extends StreamAction
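
A hedged usage sketch of the new `checkpointLocation` parameter (assumed to live inside a suite mixing in `StreamTest`; `inputData` and the temp directory are made up for the example): stopping a query and restarting it against the same checkpoint is how state recovery across runs gets exercised.

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.Utils

val inputData = MemoryStream[Int]
val aggregated = inputData.toDF().agg(sum("value"))
val checkpointDir = Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath

testStream(aggregated, OutputMode.Complete())(
  StartStream(checkpointLocation = checkpointDir),
  AddData(inputData, 1, 2, 3),
  CheckLastBatch(6L),
  StopStream,
  StartStream(checkpointLocation = checkpointDir),  // resumes from the saved state
  AddData(inputData, 4),
  CheckLastBatch(10L)
)
```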

/** Advance the trigger clock's time manually. */
@@ -349,20 +350,22 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
""".stripMargin)
}

val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
var manualClockExpectedTime = -1L
val defaultCheckpointLocation =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
try {
startedTest.foreach { action =>
logInfo(s"Processing test stream action: $action")
action match {
case StartStream(trigger, triggerClock, additionalConfs) =>
case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
verify(currentStream == null, "stream already running")
verify(triggerClock.isInstanceOf[SystemClock]
|| triggerClock.isInstanceOf[StreamManualClock],
"Use either SystemClock or StreamManualClock to start the stream")
if (triggerClock.isInstanceOf[StreamManualClock]) {
manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
}
val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)

additionalConfs.foreach(pair => {
val value =
@@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
verify(currentStream != null || lastStream != null,
"cannot assert when no stream has been started")
val streamToAssert = Option(currentStream).getOrElse(lastStream)
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
try {
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
} catch {
case NonFatal(e) =>
failTest(s"Assert on query failed: ${a.message}", e)
}

case a: Assert =>
val streamToAssert = Option(currentStream).getOrElse(lastStream)