From b7aeed6af2aaf6eb347dd0a492a62e6530900eb5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 8 Sep 2017 11:36:02 -0700 Subject: [PATCH 01/14] couldn't repro --- .../execution/streaming/StreamExecution.scala | 3 ++- .../sql/execution/streaming/memory.scala | 1 + .../streaming/statefulOperators.scala | 2 ++ .../streaming/StreamingAggregationSuite.scala | 24 +++++++++++++++++++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 71088ff6386b..1d58291909d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -685,7 +685,8 @@ class StreamExecution( runId, currentBatchId, offsetSeqMetadata) - lastExecution.executedPlan // Force the lazy generation of execution plan + println("\n\n\n") + println(lastExecution.executedPlan) // Force the lazy generation of execution plan } val nextBatch = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index c9784c093b40..39bc52c608d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -212,6 +212,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi latestBatchId.isEmpty || batchId > latestBatchId.get } if (notCommitted) { + println(data.queryExecution.toRdd.toDebugString) logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e46356392c51..e57714d054de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -200,7 +200,9 @@ case class StateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + println(s"Restore iterHasNext: ${iter.hasNext}") iter.flatMap { row => + println(row) val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e0979ce296c3..9d4ee73080d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -49,6 +49,29 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ + test("simple count, update mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDS() + .coalesce(1) + .groupBy() + .count() + .as[Long] + + testStream(aggregated, Complete())( + AddData(inputData, 3), + CheckLastBatch(1), + AddData(inputData), + CheckLastBatch(1), + AddData(inputData, 2, 4), + CheckLastBatch(3), + AddData(inputData), + CheckLastBatch(3), + 
StopStream + ) + } + /* test("simple count, update mode") { val inputData = MemoryStream[Int] @@ -381,4 +404,5 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(streamInput, 0, 1, 2, 3), CheckLastBatch((0, 0, 2), (1, 1, 3))) } + */ } From 4a7d1240196cc4660d33aef33d893526da5f0ceb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Sep 2017 10:44:15 -0700 Subject: [PATCH 02/14] save --- .../streaming/StreamingAggregationSuite.scala | 156 ++++++++++++++++-- 1 file changed, 140 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 9d4ee73080d4..19011a00d91f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -23,7 +23,8 @@ import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.SparkPlan @@ -32,7 +33,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ -import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType object FailureSinglton { @@ -50,27 +51,150 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ test("simple count, update mode") { - val inputData = MemoryStream[Int] + /* + class NonLocalRelationSource extends Source { + private var nextData: Seq[Int] = Seq.empty + private var counter = 0L + def addData(data: Int*): Unit = { + nextData = data + counter += data.length + } - val aggregated = - inputData.toDS() + def noData: Unit = { + counter += 1 + } + + override def getOffset: Option[Offset] = if (counter == 0) None else Some(LongOffset(counter)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val rdd = spark.sparkContext.parallelize(nextData, nextData.length) + .map(i => InternalRow(i)) // we don't really care about the values in this test + nextData = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = {} + } + + val inputSource = new NonLocalRelationSource + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + val aggregated: Dataset[Long] = + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(1) + .groupBy() + .count() + .as[Long] + + val sq = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + + inputSource.addData(1) + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 1L) + + inputSource.noData + sq.processAllAvailable() + + checkDataset( + 
spark.table("agg_test").as[Long], + 1L) + + inputSource.addData(2, 3) + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + + inputSource.noData + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + } finally { + sq.stop() + } + } + } + */ + + withTempDir { tempDir => + val inputStream = MemoryStream[Int] + + val sq = inputStream.toDS() .coalesce(1) .groupBy() .count() .as[Long] + .writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + inputStream.addData(1) + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 1L) + + inputStream.addData() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 1L) + } finally { + sq.stop() + } - testStream(aggregated, Complete())( - AddData(inputData, 3), - CheckLastBatch(1), - AddData(inputData), - CheckLastBatch(1), - AddData(inputData, 2, 4), - CheckLastBatch(3), - AddData(inputData), - CheckLastBatch(3), - StopStream - ) + val sq2 = inputStream.toDS() + .coalesce(2) + .groupBy() + .count() + .as[Long] + .writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + inputStream.addData(2, 3) + sq2.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + + inputStream.addData() + sq2.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + } finally { + sq2.stop() + } + } } + /* test("simple count, update mode") { val inputData = MemoryStream[Int] From 00fa5923c7663f58df72937626bfadac5dc2f1fd Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Sep 2017 21:32:30 -0700 Subject: [PATCH 03/14] ready for review --- .../streaming/IncrementalExecution.scala | 29 +- .../execution/streaming/StreamExecution.scala | 4 +- .../sql/execution/streaming/memory.scala | 1 - .../streaming/statefulOperators.scala | 31 +- .../FlatMapGroupsWithStateSuite.scala | 81 +++- .../streaming/StreamingAggregationSuite.scala | 345 ++++++++++-------- 6 files changed, 325 insertions(+), 166 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 258a64216136..8ba2cb67cbc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -24,8 +24,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.streaming.OutputMode /** @@ -89,7 +91,7 @@ class IncrementalExecution( override def apply(plan: SparkPlan): SparkPlan = plan transform { case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, - StateStoreRestoreExec(keys2, None, child))) => + StateStoreRestoreExec(_, None, child))) => val aggStateInfo = 
nextStatefulOperationStateInfo StateStoreSaveExec( keys, @@ -117,8 +119,31 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations + override def preparations: Seq[Rule[SparkPlan]] = Seq( + state, + EnsureStatefulOpPartitioning) ++ super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } + +object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { + // Needs to be transformUp to avoid extra shuffles + override def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case ss: StatefulOperator => + val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions + val keys = ss.keyExpressions + val child = ss.child + val expectedPartitioning = if (keys.isEmpty) { + SinglePartition + } else { + HashPartitioning(keys, numPartitions) + } + if (child.outputPartitioning.guarantees(expectedPartitioning) && + child.execute().getNumPartitions == expectedPartitioning.numPartitions) { + ss + } else { + ss.withNewChildren(ShuffleExchange(expectedPartitioning, child) :: Nil) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 1d58291909d2..c6c4dd3ddc34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -685,8 +685,7 @@ class StreamExecution( runId, currentBatchId, offsetSeqMetadata) - println("\n\n\n") - println(lastExecution.executedPlan) // Force the lazy generation of execution plan + lastExecution.executedPlan // Force the lazy generation of execution plan } val nextBatch = @@ -801,6 +800,7 @@ class StreamExecution( if (streamDeathCause != null) { throw streamDeathCause } + if (!isActive) return awaitBatchLock.lock() try { noNewData = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 39bc52c608d2..c9784c093b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -212,7 +212,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi latestBatchId.isEmpty || batchId > latestBatchId.get } if (notCommitted) { - println(data.queryExecution.toRdd.toDebugString) logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e57714d054de..a276da89069b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import 
org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -53,6 +53,10 @@ case class StatefulOperatorStateInfo( trait StatefulOperator extends SparkPlan { def stateInfo: Option[StatefulOperatorStateInfo] + def child: SparkPlan + + def keyExpressions: Seq[Attribute] + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") @@ -200,13 +204,16 @@ case class StateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - println(s"Restore iterHasNext: ${iter.hasNext}") - iter.flatMap { row => - println(row) - val key = getKey(row) - val savedState = store.get(key) - numOutputRows += 1 - row +: Option(savedState).toSeq + val hasInput = iter.hasNext + if (!hasInput && keyExpressions.isEmpty) { + store.iterator().map(_.value) + } else { + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + numOutputRows += 1 + row +: Option(savedState).toSeq + } } } } @@ -214,6 +221,14 @@ case class StateStoreRestoreExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d74a5c701ef..95e00cd8ff6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date import java.util.concurrent.ConcurrentHashMap @@ -24,7 +25,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{DataFrame, Dataset, Encoder} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning @@ -33,7 +34,7 @@ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore -import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ @@ -873,6 +874,82 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf 
assert(e.getMessage === "The output mode of function should be append or update") } + test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") { + val inputSource = new NonLocalRelationSource(spark) + + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + val checkpoint = new File(tempDir, "checkpoint").getAbsolutePath + val data = new File(tempDir, "data").getAbsolutePath + + def startQuery(df: Dataset[Int]): StreamingQuery = { + df.groupByKey(_ % 1) // just to give it a fake key + .flatMapGroupsWithState(OutputMode.Append(), NoTimeout()) { + (key: Int, values: Iterator[Int], state: GroupState[Int]) => + // function that returns the values that exceed the max value ever seen in past + // triggers + val existing = state.getOption.getOrElse(Int.MinValue) + values.filter { value => + val max = state.getOption.getOrElse(Int.MinValue) + if (value > max) { + state.update(value) + } + value > existing + } + } + .writeStream + .format("parquet") + .outputMode("append") + .option("checkpointLocation", checkpoint) + .start(data) + } + + + val sq = startQuery(spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(1) + .as[Int]) + + try { + + inputSource.addData(1) + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.read.parquet(data).as[Int], + 1) + + } finally { + sq.stop() + } + + val sq2 = startQuery(spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(2) + .as[Int]) + + try { + sq2.processAllAvailable() + inputSource.addData(0) + inputSource.addData(3) + inputSource.addData(4) + inputSource.releaseLock() + sq2.processAllAvailable() + + checkDataset( + spark.read.parquet(data).as[Int], + 4, 3, 1) + + } finally { + sq2.stop() + } + } + } + } + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { // Function to maintain running count up to 2, and then remove the count diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 19011a00d91f..f4aa3ab9fbb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.sql.streaming import java.util.{Locale, TimeZone} +import java.util.concurrent.CountDownLatch import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ @@ -35,8 +38,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType +import 
org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -object FailureSinglton { +object FailureSingleton { var firstTime = true } @@ -50,151 +54,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - test("simple count, update mode") { - /* - class NonLocalRelationSource extends Source { - private var nextData: Seq[Int] = Seq.empty - private var counter = 0L - def addData(data: Int*): Unit = { - nextData = data - counter += data.length - } - - def noData: Unit = { - counter += 1 - } - - override def getOffset: Option[Offset] = if (counter == 0) None else Some(LongOffset(counter)) - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val rdd = spark.sparkContext.parallelize(nextData, nextData.length) - .map(i => InternalRow(i)) // we don't really care about the values in this test - nextData = Seq.empty - spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() - } - override def schema: StructType = MockSourceProvider.fakeSchema - override def stop(): Unit = {} - } - - val inputSource = new NonLocalRelationSource - MockSourceProvider.withMockSources(inputSource) { - withTempDir { tempDir => - val aggregated: Dataset[Long] = - spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(1) - .groupBy() - .count() - .as[Long] - - val sq = aggregated.writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - - inputSource.addData(1) - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - - inputSource.noData - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - - inputSource.addData(2, 3) - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - - inputSource.noData - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - } finally { - sq.stop() - } - } - } - */ - - withTempDir { tempDir => - val inputStream = MemoryStream[Int] - - val sq = inputStream.toDS() - .coalesce(1) - .groupBy() - .count() - .as[Long] - .writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - inputStream.addData(1) - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - - inputStream.addData() - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - } finally { - sq.stop() - } - - val sq2 = inputStream.toDS() - .coalesce(2) - .groupBy() - .count() - .as[Long] - .writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - inputStream.addData(2, 3) - sq2.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - - inputStream.addData() - sq2.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - } finally { - sq2.stop() - } - } - } - /* test("simple count, update mode") { val inputData = MemoryStream[Int] @@ -373,12 +232,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest testQuietly("midbatch failure") { val inputData = MemoryStream[Int] - FailureSinglton.firstTime = true + FailureSingleton.firstTime = true val aggregated = inputData.toDS() .map { i => - if (i == 4 && FailureSinglton.firstTime) { - FailureSinglton.firstTime = false 
+ if (i == 4 && FailureSingleton.firstTime) { + FailureSingleton.firstTime = false sys.error("injected failure") } @@ -529,4 +388,188 @@ class StreamingAggregationSuite extends StateStoreMetricsTest CheckLastBatch((0, 0, 2), (1, 1, 3))) } */ + + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned accordingly") { + val inputSource = new NonLocalRelationSource(spark) + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + val aggregated: Dataset[Long] = + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(1) + .groupBy() + .count() + .as[Long] + + val sq = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + + inputSource.addData(1) + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 1L) + + inputSource.addData() + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 1L) + + inputSource.addData(2, 3) + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + + inputSource.addData() + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[Long], + 3L) + } finally { + sq.stop() + } + } + } + } + + test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") { + val inputSource = new NonLocalRelationSource(spark) + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + + val sq = spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(1) + .groupBy('a % 1) // just to give it a fake key + .count() + .as[(Long, Long)] + .writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + + inputSource.addData(1) + inputSource.releaseLock() + sq.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[(Long, Long)], + (0L, 1L)) + + } finally { + sq.stop() + } + + val sq2 = spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(2) + .groupBy('a % 1) // just to give it a fake key + .count() + .as[(Long, Long)] + .writeStream + .format("memory") + .outputMode("complete") + .queryName("agg_test") + .option("checkpointLocation", tempDir.getAbsolutePath) + .start() + + try { + sq2.processAllAvailable() + inputSource.addData(2) + inputSource.addData(3) + inputSource.addData(4) + inputSource.releaseLock() + sq2.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[(Long, Long)], + (0L, 4L)) + + inputSource.addData() + inputSource.releaseLock() + sq2.processAllAvailable() + + checkDataset( + spark.table("agg_test").as[(Long, Long)], + (0L, 4L)) + } finally { + sq2.stop() + } + } + } + } +} + +/** + * LocalRelation has some optimized properties during Spark planning. In order for the bugs in + * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a + * BlockRDD since it accepts 0 partitions. One requirement for the one of the bugs is the use of + * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a 0 partition + * parentRDD. 
+ */ +class NonLocalRelationSource(spark: SparkSession) extends Source { + private var counter = 0L + private val blockMgr = SparkEnv.get.blockManager + private var blocks: Seq[BlockId] = Seq.empty + + private var streamLock: CountDownLatch = new CountDownLatch(1) + + def addData(data: Int*): Unit = { + if (streamLock.getCount == 0) { + streamLock = new CountDownLatch(1) + } + synchronized { + if (data.nonEmpty) { + counter += data.length + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil + } else { + counter += 1 + } + } + } + + def releaseLock(): Unit = streamLock.countDown() + + override def getOffset: Option[Offset] = { + streamLock.await() + synchronized { + if (counter == 0) None else Some(LongOffset(counter)) + } + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) + .map(i => InternalRow(i)) // we don't really care about the values in this test + blocks = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = { + blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + } } From 090044ca089870befff464d37f098c4d4fd19657 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Sep 2017 21:33:05 -0700 Subject: [PATCH 04/14] uncomment --- .../apache/spark/sql/streaming/StreamingAggregationSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f4aa3ab9fbb9..8190fda70d48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -54,7 +54,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - /* test("simple count, update mode") { val inputData = MemoryStream[Int] @@ -387,7 +386,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(streamInput, 0, 1, 2, 3), CheckLastBatch((0, 0, 2), (1, 1, 3))) } - */ test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned accordingly") { val inputSource = new NonLocalRelationSource(spark) From 2f949517ea1d667aee8ca6838a374e222492c0c7 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Sep 2017 22:06:55 -0700 Subject: [PATCH 05/14] Added more checks --- .../FlatMapGroupsWithStateSuite.scala | 19 +++++++- .../streaming/StreamingAggregationSuite.scala | 48 ++++++++++++++++++- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 95e00cd8ff6c..5e303a51d6ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProj import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import 
org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} +import org.apache.spark.sql.execution.{RDDScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} @@ -917,6 +918,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf inputSource.releaseLock() sq.processAllAvailable() + val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { case ss: FlatMapGroupsWithStateExec => ss } + .head + assert(restore1.child.outputPartitioning.numPartitions === + spark.sessionState.conf.numShufflePartitions) + checkDataset( spark.read.parquet(data).as[Int], 1) @@ -939,6 +947,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf inputSource.releaseLock() sq2.processAllAvailable() + val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { case ss: FlatMapGroupsWithStateExec => ss } + .head + assert(restore2.child.outputPartitioning.numPartitions === + spark.sessionState.conf.numShufflePartitions) + checkDataset( spark.read.parquet(data).as[Int], 4, 3, 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 8190fda70d48..2d5c39509f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed @@ -417,10 +419,30 @@ class StreamingAggregationSuite extends StateStoreMetricsTest spark.table("agg_test").as[Long], 1L) + val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore1.child match { + case wscg: WholeStageCodegenExec => + assert(wscg.outputPartitioning.numPartitions === 1) + assert(wscg.child.isInstanceOf[HashAggregateExec], "Shouldn't require shuffling") + case _ => fail("Expected no shuffling") + } + inputSource.addData() inputSource.releaseLock() sq.processAllAvailable() + val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { 
case ss: StateStoreRestoreExec => ss } + .head + restore2.child match { + case shuffle: ShuffleExchange => assert(shuffle.newPartitioning.numPartitions === 1) + case _ => fail("Expected shuffling when there was no data") + } + checkDataset( spark.table("agg_test").as[Long], 1L) @@ -472,6 +494,17 @@ class StreamingAggregationSuite extends StateStoreMetricsTest inputSource.releaseLock() sq.processAllAvailable() + val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore1.child match { + case shuffle: ShuffleExchange => + assert(shuffle.newPartitioning.numPartitions === + spark.sessionState.conf.numShufflePartitions) + case _ => fail(s"Expected shuffling but got: ${restore1.child}") + } + checkDataset( spark.table("agg_test").as[(Long, Long)], (0L, 1L)) @@ -502,6 +535,19 @@ class StreamingAggregationSuite extends StateStoreMetricsTest inputSource.releaseLock() sq2.processAllAvailable() + val restore2 = sq2.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore2.child match { + case wscg: WholeStageCodegenExec => + assert(wscg.outputPartitioning.numPartitions === + spark.sessionState.conf.numShufflePartitions) + case _ => + fail("Shouldn't require shuffling as HashAggregateExec should have asked for a " + + s"shuffle. But got: ${restore2.child}") + } + checkDataset( spark.table("agg_test").as[(Long, Long)], (0L, 4L)) From 12cf02a10ff7219f1ed405c37c2ac87c65a6c798 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Sep 2017 22:59:54 -0700 Subject: [PATCH 06/14] save --- .../streaming/IncrementalExecution.scala | 14 +-- .../streaming/statefulOperators.scala | 2 - .../streaming/StreamingAggregationSuite.scala | 85 ++++++++++--------- 3 files changed, 51 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 8ba2cb67cbc8..2055ac0410fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -133,17 +133,19 @@ object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { case ss: StatefulOperator => val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions val keys = ss.keyExpressions - val child = ss.child val expectedPartitioning = if (keys.isEmpty) { SinglePartition } else { HashPartitioning(keys, numPartitions) } - if (child.outputPartitioning.guarantees(expectedPartitioning) && - child.execute().getNumPartitions == expectedPartitioning.numPartitions) { - ss - } else { - ss.withNewChildren(ShuffleExchange(expectedPartitioning, child) :: Nil) + val children = ss.children.map { child => + if (child.outputPartitioning.guarantees(expectedPartitioning) && + child.execute().getNumPartitions == expectedPartitioning.numPartitions) { + child + } else { + ShuffleExchange(expectedPartitioning, child) + } } + ss.withNewChildren(children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index a276da89069b..6d86009890c5 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -53,8 +53,6 @@ case class StatefulOperatorStateInfo( trait StatefulOperator extends SparkPlan { def stateInfo: Option[StatefulOperatorStateInfo] - def child: SparkPlan - def keyExpressions: Seq[Attribute] protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 2d5c39509f48..bc333eabab10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed @@ -389,6 +389,37 @@ class StreamingAggregationSuite extends StateStoreMetricsTest CheckLastBatch((0, 0, 2), (1, 1, 3))) } + private def checkAggregationChain( + sq: StreamingQuery, + requiresShuffling: Boolean, + expectedPartition: Int): Unit = { + val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery + .lastExecution.executedPlan + val restore = executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore.child match { + case node: UnaryExecNode => + assert(node.outputPartitioning.numPartitions === expectedPartition) + if (requiresShuffling) { + assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") + } else { + assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") + } + + case _ => fail("Expected no shuffling") + } + var reachedRestore = false + // Check that there should be no exchanges after `StateStoreRestoreExec` + executedPlan.foreachUp { p => + if (reachedRestore) { + assert(!p.isInstanceOf[Exchange], "There should be no further exchanges") + } else { + reachedRestore = p.isInstanceOf[StateStoreRestoreExec] + } + } + } + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned accordingly") { val inputSource = new NonLocalRelationSource(spark) MockSourceProvider.withMockSources(inputSource) { @@ -419,29 +450,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest spark.table("agg_test").as[Long], 1L) - val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: StateStoreRestoreExec => ss } - .head - restore1.child match { - case wscg: WholeStageCodegenExec => - assert(wscg.outputPartitioning.numPartitions === 1) - assert(wscg.child.isInstanceOf[HashAggregateExec], "Shouldn't require shuffling") - case _ => fail("Expected no shuffling") - } + checkAggregationChain(sq, 
requiresShuffling = false, 1) inputSource.addData() inputSource.releaseLock() sq.processAllAvailable() - val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: StateStoreRestoreExec => ss } - .head - restore2.child match { - case shuffle: ShuffleExchange => assert(shuffle.newPartitioning.numPartitions === 1) - case _ => fail("Expected shuffling when there was no data") - } + checkAggregationChain(sq, requiresShuffling = true, 1) checkDataset( spark.table("agg_test").as[Long], @@ -494,16 +509,10 @@ class StreamingAggregationSuite extends StateStoreMetricsTest inputSource.releaseLock() sq.processAllAvailable() - val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: StateStoreRestoreExec => ss } - .head - restore1.child match { - case shuffle: ShuffleExchange => - assert(shuffle.newPartitioning.numPartitions === - spark.sessionState.conf.numShufflePartitions) - case _ => fail(s"Expected shuffling but got: ${restore1.child}") - } + checkAggregationChain( + sq, + requiresShuffling = true, + spark.sessionState.conf.numShufflePartitions) checkDataset( spark.table("agg_test").as[(Long, Long)], @@ -535,18 +544,10 @@ class StreamingAggregationSuite extends StateStoreMetricsTest inputSource.releaseLock() sq2.processAllAvailable() - val restore2 = sq2.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: StateStoreRestoreExec => ss } - .head - restore2.child match { - case wscg: WholeStageCodegenExec => - assert(wscg.outputPartitioning.numPartitions === - spark.sessionState.conf.numShufflePartitions) - case _ => - fail("Shouldn't require shuffling as HashAggregateExec should have asked for a " + - s"shuffle. But got: ${restore2.child}") - } + checkAggregationChain( + sq2, + requiresShuffling = false, // doesn't require extra shuffle as HashAggregate adds it + spark.sessionState.conf.numShufflePartitions) checkDataset( spark.table("agg_test").as[(Long, Long)], From c5b7f230ebbdabc373d8df1478993e08d420c1f3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 12 Sep 2017 10:23:54 -0700 Subject: [PATCH 07/14] add required child distribution --- .../spark/sql/execution/streaming/statefulOperators.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6d86009890c5..ccc3c770a94a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -366,6 +366,14 @@ case class StateStoreSaveExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** Physical operator for executing streaming Deduplicate. 
*/ From e178d1033f3ad4bac0223ca5d4001a95305d6cdb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Sep 2017 16:20:20 -0700 Subject: [PATCH 08/14] concise tests --- .../streaming/IncrementalExecution.scala | 29 +-- .../streaming/statefulOperators.scala | 6 +- .../FlatMapGroupsWithStateSuite.scala | 90 ------- .../spark/sql/streaming/StreamTest.scala | 7 +- .../streaming/StreamingAggregationSuite.scala | 222 +++++++----------- 5 files changed, 107 insertions(+), 247 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 2055ac0410fe..ca44891ced1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -21,10 +21,10 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.exchange.ShuffleExchange @@ -119,9 +119,8 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = Seq( - state, - EnsureStatefulOpPartitioning) ++ super.preparations + override def preparations: Seq[Rule[SparkPlan]] = + Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } @@ -130,15 +129,17 @@ class IncrementalExecution( object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { // Needs to be transformUp to avoid extra shuffles override def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case ss: StatefulOperator => + case so: StatefulOperator => val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions - val keys = ss.keyExpressions - val expectedPartitioning = if (keys.isEmpty) { - SinglePartition - } else { - HashPartitioning(keys, numPartitions) - } - val children = ss.children.map { child => + val distributions = so.requiredChildDistribution + val children = so.children.zip(distributions).map { case (child, reqDistribution) => + val expectedPartitioning = reqDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) + case _ => throw new AnalysisException("Unexpected distribution expected for " + + s"Stateful Operator: $so. 
Expect AllTuples or ClusteredDistribution but got " + + s"$reqDistribution.") + } if (child.outputPartitioning.guarantees(expectedPartitioning) && child.execute().getNumPartitions == expectedPartitioning.numPartitions) { child @@ -146,6 +147,6 @@ object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { ShuffleExchange(expectedPartitioning, child) } } - ss.withNewChildren(children) + so.withNewChildren(children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index ccc3c770a94a..d6566b8e6b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -53,8 +53,6 @@ case class StatefulOperatorStateInfo( trait StatefulOperator extends SparkPlan { def stateInfo: Option[StatefulOperatorStateInfo] - def keyExpressions: Seq[Attribute] - protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") @@ -204,6 +202,10 @@ case class StateStoreRestoreExec( val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { + // If our `keyExpressions` are empty, we're getting a global aggregation. In that case + // the `HashAggregateExec` will output a 0 value for the partial merge. We need to + // restore the value, so that we don't overwrite our state with a 0 value, but rather + // merge the 0 with existing state. store.iterator().map(_.value) } else { iter.flatMap { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 5e303a51d6ec..519b14e33ab6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -875,96 +875,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(e.getMessage === "The output mode of function should be append or update") } - test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") { - val inputSource = new NonLocalRelationSource(spark) - - MockSourceProvider.withMockSources(inputSource) { - withTempDir { tempDir => - val checkpoint = new File(tempDir, "checkpoint").getAbsolutePath - val data = new File(tempDir, "data").getAbsolutePath - - def startQuery(df: Dataset[Int]): StreamingQuery = { - df.groupByKey(_ % 1) // just to give it a fake key - .flatMapGroupsWithState(OutputMode.Append(), NoTimeout()) { - (key: Int, values: Iterator[Int], state: GroupState[Int]) => - // function that returns the values that exceed the max value ever seen in past - // triggers - val existing = state.getOption.getOrElse(Int.MinValue) - values.filter { value => - val max = state.getOption.getOrElse(Int.MinValue) - if (value > max) { - state.update(value) - } - value > existing - } - } - .writeStream - .format("parquet") - .outputMode("append") - .option("checkpointLocation", checkpoint) - .start(data) - } - - - val sq = startQuery(spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(1) - .as[Int]) - - try { - - 
inputSource.addData(1) - inputSource.releaseLock() - sq.processAllAvailable() - - val restore1 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: FlatMapGroupsWithStateExec => ss } - .head - assert(restore1.child.outputPartitioning.numPartitions === - spark.sessionState.conf.numShufflePartitions) - - checkDataset( - spark.read.parquet(data).as[Int], - 1) - - } finally { - sq.stop() - } - - val sq2 = startQuery(spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(2) - .as[Int]) - - try { - sq2.processAllAvailable() - inputSource.addData(0) - inputSource.addData(3) - inputSource.addData(4) - inputSource.releaseLock() - sq2.processAllAvailable() - - val restore2 = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan - .collect { case ss: FlatMapGroupsWithStateExec => ss } - .head - assert(restore2.child.outputPartitioning.numPartitions === - spark.sessionState.conf.numShufflePartitions) - - checkDataset( - spark.read.parquet(data).as[Int], - 4, 3, 1) - - } finally { - sq2.stop() - } - } - } - } - def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { // Function to maintain running count up to 2, and then remove the count diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4f8764060d92..1e1400c5413d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class StartStream( trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, - additionalConfs: Map[String, String] = Map.empty) + additionalConfs: Map[String, String] = Map.empty, + queryName: String = null) extends StreamAction /** Advance the trigger clock's time manually. 
*/ @@ -355,7 +356,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock, additionalConfs) => + case StartStream(trigger, triggerClock, additionalConfs, queryName) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -378,7 +379,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be sparkSession .streams .startQuery( - None, + Option(queryName), Some(metadataRoot), stream, sink, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index bc333eabab10..01b0a854c856 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -389,19 +390,29 @@ class StreamingAggregationSuite extends StateStoreMetricsTest CheckLastBatch((0, 0, 2), (1, 1, 3))) } + /** + * This method verifies certain properties in the SparkPlan of a streaming aggregation. + * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired + * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already + * provides the expected data distribution. + * + * The second thing it checks that the child provides the expected number of partitions. + * + * The third thing it checks that we don't add an unnecessary shuffle in-between + * `StateStoreRestoreExec` and `StateStoreSaveExec`. + */ private def checkAggregationChain( - sq: StreamingQuery, - requiresShuffling: Boolean, - expectedPartition: Int): Unit = { - val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery - .lastExecution.executedPlan + se: StreamExecution, + expectShuffling: Boolean, + expectedPartition: Int): Boolean = { + val executedPlan = se.lastExecution.executedPlan val restore = executedPlan .collect { case ss: StateStoreRestoreExec => ss } .head restore.child match { case node: UnaryExecNode => assert(node.outputPartitioning.numPartitions === expectedPartition) - if (requiresShuffling) { + if (expectShuffling) { assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") } else { assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") @@ -418,10 +429,20 @@ class StreamingAggregationSuite extends StateStoreMetricsTest reachedRestore = p.isInstanceOf[StateStoreRestoreExec] } } + true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned accordingly") { - val inputSource = new NonLocalRelationSource(spark) + /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ + case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + data.foreach(source.addData) + source.releaseLock() + (source, LongOffset(source.counter)) + } + } + + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => val aggregated: Dataset[Long] = @@ -433,150 +454,77 @@ class StreamingAggregationSuite extends StateStoreMetricsTest .count() .as[Long] - val sq = aggregated.writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - - inputSource.addData(1) - inputSource.releaseLock() - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - - checkAggregationChain(sq, requiresShuffling = false, 1) - - inputSource.addData() - inputSource.releaseLock() - sq.processAllAvailable() - - checkAggregationChain(sq, requiresShuffling = true, 1) - - checkDataset( - spark.table("agg_test").as[Long], - 1L) - - inputSource.addData(2, 3) - inputSource.releaseLock() - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - - inputSource.addData() - inputSource.releaseLock() - sq.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[Long], - 3L) - } finally { - sq.stop() - } + testStream(aggregated, Complete())( + AddBlockData(inputSource, Seq(1)), + CheckLastBatch(1), + AssertOnQuery(se => checkAggregationChain(se, expectShuffling = false, 1)), + AddBlockData(inputSource), // create an empty trigger + AssertOnQuery(se => checkAggregationChain(se, expectShuffling = true, 1)), + CheckLastBatch(1), + AddBlockData(inputSource, Seq(2, 3)), + CheckLastBatch(3), + AddBlockData(inputSource), + CheckLastBatch(3), + StopStream + ) } } } test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") { - val inputSource = new NonLocalRelationSource(spark) + val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => - val sq = spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(1) - .groupBy('a % 1) // just to give it a fake key - .count() - .as[(Long, Long)] - .writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - - inputSource.addData(1) - inputSource.releaseLock() - sq.processAllAvailable() - - checkAggregationChain( - sq, - requiresShuffling = true, - spark.sessionState.conf.numShufflePartitions) - - checkDataset( - spark.table("agg_test").as[(Long, Long)], - (0L, 1L)) - - } finally { - sq.stop() + def createDf(partitions: Int): Dataset[(Long, Long)] = { + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load() + .coalesce(partitions) + .groupBy('a % 1) // just to give it a fake key + .count() + .as[(Long, Long)] } - val sq2 = spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(2) - .groupBy('a % 1) // just to give it a fake key - .count() - .as[(Long, Long)] - .writeStream - .format("memory") - .outputMode("complete") - .queryName("agg_test") - .option("checkpointLocation", tempDir.getAbsolutePath) - .start() - - try { - 
sq2.processAllAvailable() - inputSource.addData(2) - inputSource.addData(3) - inputSource.addData(4) - inputSource.releaseLock() - sq2.processAllAvailable() - - checkAggregationChain( - sq2, - requiresShuffling = false, // doesn't require extra shuffle as HashAggregate adds it - spark.sessionState.conf.numShufflePartitions) - - checkDataset( - spark.table("agg_test").as[(Long, Long)], - (0L, 4L)) - - inputSource.addData() - inputSource.releaseLock() - sq2.processAllAvailable() - - checkDataset( - spark.table("agg_test").as[(Long, Long)], - (0L, 4L)) - } finally { - sq2.stop() - } + val confs = Map(SQLConf.CHECKPOINT_LOCATION.key -> tempDir.getAbsolutePath) + + testStream(createDf(1), Complete())( + StartStream(additionalConfs = confs, queryName = "agg_test"), + AddBlockData(inputSource, Seq(1)), + AssertOnQuery { se => + checkAggregationChain( + se, + expectShuffling = true, + spark.sessionState.conf.numShufflePartitions) + }, + CheckLastBatch((0L, 1L)), + StopStream + ) + + testStream(createDf(2), Complete())( + StartStream(additionalConfs = confs, queryName = "agg_test"), + AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), + AssertOnQuery { se => + checkAggregationChain( + se, + expectShuffling = false, + spark.sessionState.conf.numShufflePartitions) + }, + CheckLastBatch((0L, 4L)), + AddBlockData(inputSource), + CheckLastBatch((0L, 4L)), + StopStream + ) } } } } /** - * LocalRelation has some optimized properties during Spark planning. In order for the bugs in - * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a - * BlockRDD since it accepts 0 partitions. One requirement for the one of the bugs is the use of - * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a 0 partition - * parentRDD. + * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will. 
*/ -class NonLocalRelationSource(spark: SparkSession) extends Source { - private var counter = 0L +class BlockRDDBackedSource(spark: SparkSession) extends Source { + var counter = 0L private val blockMgr = SparkEnv.get.blockManager private var blocks: Seq[BlockId] = Seq.empty @@ -588,13 +536,11 @@ class NonLocalRelationSource(spark: SparkSession) extends Source { } synchronized { if (data.nonEmpty) { - counter += data.length val id = TestBlockId(counter.toString) blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) blocks ++= id :: Nil - } else { - counter += 1 } + counter += 1 } } From 6cc8c46e83b188c8f952cdf49cf6ff89604fbd43 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Sep 2017 23:29:09 -0700 Subject: [PATCH 09/14] address comments --- .../spark/sql/execution/SparkPlanTest.scala | 4 + .../IncrementalExecutionRulesSuite.scala | 123 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 16 ++- .../streaming/StreamingAggregationSuite.scala | 44 ++++--- 4 files changed, 164 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index b29e822add8b..8349f25f1b27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,6 +231,10 @@ object SparkPlanTest { } } + /** + * + */ + /** * Runs the plan * @param outputPlan SparkPlan to be executed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala new file mode 100644 index 000000000000..9af44fac91fb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} +import org.apache.spark.sql.test.SharedSQLContext + +class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext { + + import testImplicits._ + super.beforeAll() + + private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution generates Exchange with HashPartitioning", + baseDf.queryExecution.sparkPlan, + keys => ClusteredDistribution(keys), + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning", + baseDf.coalesce(1).queryExecution.sparkPlan, + keys => ClusteredDistribution(keys), + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples generates Exchange with SinglePartition", + baseDf.queryExecution.sparkPlan, + keys => AllTuples, + keys => SinglePartition, + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples with coalesce(1) doesn't need Exchange", + baseDf.coalesce(1).queryExecution.sparkPlan, + keys => AllTuples, + keys => SinglePartition, + expectShuffle = false) + + private def testEnsureStatefulOpPartitioning( + testName: String, + inputPlan: SparkPlan, + requiredDistribution: Seq[Attribute] => Distribution, + expectedPartitioning: Seq[Attribute] => Partitioning, + expectShuffle: Boolean): Unit = { + test("EnsureStatefulOpPartitioning - " + testName) { + val operator = TestOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) + val executed = executePlan(operator, OutputMode.Complete()) + if (expectShuffle) { + val exchange = executed.children.find(_.isInstanceOf[Exchange]) + if (exchange.isEmpty) { + fail(s"Was expecting an exchange but didn't get one in:\n$executed") + } + assert(exchange.get === + ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + s"Exchange didn't have expected properties:\n${exchange.get}") + } else { + assert(!executed.children.exists(_.isInstanceOf[Exchange]), + s"Unexpected exchange found in:\n$executed") + } + } + } + + private def executePlan( + p: SparkPlan, + outputMode: OutputMode = OutputMode.Append()): SparkPlan = { + val execution = new IncrementalExecution( + spark, + null, + OutputMode.Complete(), + "chk", + UUID.randomUUID(), + 0L, + OffsetSeqMetadata()) { + override lazy val sparkPlan: SparkPlan = p transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + execution.executedPlan + } +} + +case class TestOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode 
with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 1e1400c5413d..e8ff5592cfef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -168,7 +168,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, - queryName: String = null) + checkpointLocation: String = null) extends StreamAction /** Advance the trigger clock's time manually. */ @@ -350,13 +350,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath var manualClockExpectedTime = -1L try { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock, additionalConfs, queryName) => + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -364,6 +363,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + val metadataRoot = Option(checkpointLocation).getOrElse( + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath) additionalConfs.foreach(pair => { val value = @@ -379,7 +380,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be sparkSession .streams .startQuery( - Option(queryName), + None, Some(metadataRoot), stream, sink, @@ -480,7 +481,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be verify(currentStream != null || lastStream != null, "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } case a: Assert => val streamToAssert = Option(currentStream).getOrElse(lastStream) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 01b0a854c856..8cc5338b078c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -24,23 +24,22 @@ import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.rdd.BlockRDD +import org.apache.spark.rdd.{BlockRDD, RDD} 
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Aggregate -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode, WholeStageCodegenExec} -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} object FailureSingleton { @@ -411,7 +410,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest .head restore.child match { case node: UnaryExecNode => - assert(node.outputPartitioning.numPartitions === expectedPartition) + assert(node.outputPartitioning.numPartitions === expectedPartition, + "Didn't get the expected number of partitions.") if (expectShuffling) { assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") } else { @@ -435,7 +435,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { - data.foreach(source.addData) + if (data.nonEmpty) { + data.foreach(source.addData) + } else { + // we would like to create empty blockRDD's so add an empty block here. 
+ source.addData() + } source.releaseLock() (source, LongOffset(source.counter)) } @@ -457,10 +462,14 @@ class StreamingAggregationSuite extends StateStoreMetricsTest testStream(aggregated, Complete())( AddBlockData(inputSource, Seq(1)), CheckLastBatch(1), - AssertOnQuery(se => checkAggregationChain(se, expectShuffling = false, 1)), + AssertOnQuery("Verify no shuffling") { se => + checkAggregationChain(se, expectShuffling = false, 1) + }, AddBlockData(inputSource), // create an empty trigger - AssertOnQuery(se => checkAggregationChain(se, expectShuffling = true, 1)), CheckLastBatch(1), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain(se, expectShuffling = true, 1) + }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), AddBlockData(inputSource), @@ -486,31 +495,30 @@ class StreamingAggregationSuite extends StateStoreMetricsTest .as[(Long, Long)] } - val confs = Map(SQLConf.CHECKPOINT_LOCATION.key -> tempDir.getAbsolutePath) - testStream(createDf(1), Complete())( - StartStream(additionalConfs = confs, queryName = "agg_test"), + StartStream(checkpointLocation = tempDir.getAbsolutePath), AddBlockData(inputSource, Seq(1)), - AssertOnQuery { se => + CheckLastBatch((0L, 1L)), + AssertOnQuery("Verify addition of exchange operator") { se => checkAggregationChain( se, expectShuffling = true, spark.sessionState.conf.numShufflePartitions) }, - CheckLastBatch((0L, 1L)), StopStream ) testStream(createDf(2), Complete())( - StartStream(additionalConfs = confs, queryName = "agg_test"), + StartStream(checkpointLocation = tempDir.getAbsolutePath), + Execute(se => se.processAllAvailable()), AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), - AssertOnQuery { se => + CheckLastBatch((0L, 4L)), + AssertOnQuery("Verify no exchange added") { se => checkAggregationChain( se, expectShuffling = false, spark.sessionState.conf.numShufflePartitions) }, - CheckLastBatch((0L, 4L)), AddBlockData(inputSource), CheckLastBatch((0L, 4L)), StopStream From be3912552d200fced24f69c436b59a2c24639380 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 17 Sep 2017 18:32:40 -0700 Subject: [PATCH 10/14] address --- .../spark/sql/execution/SparkPlanTest.scala | 4 - .../FlatMapGroupsWithStateSuite.scala | 10 +- .../spark/sql/streaming/StreamTest.scala | 5 +- .../streaming/StreamingAggregationSuite.scala | 123 ++++++++---------- 4 files changed, 64 insertions(+), 78 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 8349f25f1b27..b29e822add8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,10 +231,6 @@ object SparkPlanTest { } } - /** - * - */ - /** * Runs the plan * @param outputPlan SparkPlan to be executed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 519b14e33ab6..9d74a5c701ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.streaming -import java.io.File import java.sql.Date import java.util.concurrent.ConcurrentHashMap @@ -25,17 +24,16 @@ import org.scalatest.BeforeAndAfterAll import 
org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction -import org.apache.spark.sql.{DataFrame, Dataset, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.{RDDScanExec, WholeStageCodegenExec} -import org.apache.spark.sql.execution.exchange.ShuffleExchange -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore -import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e8ff5592cfef..70b39b934071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -351,6 +351,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } var manualClockExpectedTime = -1L + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath try { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") @@ -363,8 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } - val metadataRoot = Option(checkpointLocation).getOrElse( - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath) + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) additionalConfs.foreach(pair => { val value = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 8cc5338b078c..642149d0ae06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -24,22 +24,20 @@ import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.rdd.{BlockRDD, RDD} +import org.apache.spark.rdd.BlockRDD import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Aggregate -import 
org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} object FailureSingleton { @@ -435,13 +433,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (data.nonEmpty) { - data.foreach(source.addData) - } else { - // we would like to create empty blockRDD's so add an empty block here. - source.addData() - } - source.releaseLock() + source.addBlocks(data: _*) (source, LongOffset(source.counter)) } } @@ -449,50 +441,60 @@ class StreamingAggregationSuite extends StateStoreMetricsTest test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { - withTempDir { tempDir => - val aggregated: Dataset[Long] = - spark.readStream - .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(1) - .groupBy() - .count() - .as[Long] - - testStream(aggregated, Complete())( - AddBlockData(inputSource, Seq(1)), - CheckLastBatch(1), - AssertOnQuery("Verify no shuffling") { se => - checkAggregationChain(se, expectShuffling = false, 1) - }, - AddBlockData(inputSource), // create an empty trigger - CheckLastBatch(1), - AssertOnQuery("Verify addition of exchange operator") { se => - checkAggregationChain(se, expectShuffling = true, 1) - }, - AddBlockData(inputSource, Seq(2, 3)), - CheckLastBatch(3), - AddBlockData(inputSource), - CheckLastBatch(3), - StopStream - ) - } + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we + // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1 + // partition. Which means that if we have an input RDD with 0 partitions, nothing gets + // executed. Therefore the StateStore's don't save any delta files for a given trigger. This + // then leads to `FileNotFoundException`s in the subsequent batch. + // This isn't the only problem though. Once we introduce a shuffle before + // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing + // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If + // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in + // `StateStoreSaveExec` losing all previous state. 
+ val aggregated: Dataset[Long] = + spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(1).groupBy().count().as[Long] + + testStream(aggregated, Complete())( + AddBlockData(inputSource, Seq(1)), + CheckLastBatch(1), + AssertOnQuery("Verify no shuffling") { se => + checkAggregationChain(se, expectShuffling = false, 1) + }, + AddBlockData(inputSource), // create an empty trigger + CheckLastBatch(1), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain(se, expectShuffling = true, 1) + }, + AddBlockData(inputSource, Seq(2, 3)), + CheckLastBatch(3), + AddBlockData(inputSource), + CheckLastBatch(3), + StopStream + ) } } - test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") { + test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + + "has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. However, when we have + // non-empty grouping keys, in streaming, we must repartition to + // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process + // all keys. This may be fine, however, if the user removes the coalesce(1) or changes to + // a `coalesce(2)` for example, then the default behavior is to shuffle to + // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStore's + // except 1 will be missing their previous delta files, which causes the stream to fail + // with FileNotFoundException. def createDf(partitions: Int): Dataset[(Long, Long)] = { spark.readStream .format((new MockSourceProvider).getClass.getCanonicalName) - .load() - .coalesce(partitions) - .groupBy('a % 1) // just to give it a fake key - .count() - .as[(Long, Long)] + .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)] } testStream(createDf(1), Complete())( @@ -536,29 +538,18 @@ class BlockRDDBackedSource(spark: SparkSession) extends Source { private val blockMgr = SparkEnv.get.blockManager private var blocks: Seq[BlockId] = Seq.empty - private var streamLock: CountDownLatch = new CountDownLatch(1) - - def addData(data: Int*): Unit = { - if (streamLock.getCount == 0) { - streamLock = new CountDownLatch(1) - } - synchronized { - if (data.nonEmpty) { - val id = TestBlockId(counter.toString) - blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) - blocks ++= id :: Nil - } + def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { + dataBlocks.foreach { data => + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil counter += 1 } + counter += 1 } - def releaseLock(): Unit = streamLock.countDown() - - override def getOffset: Option[Offset] = { - streamLock.await() - synchronized { - if (counter == 0) None else Some(LongOffset(counter)) - } + override def getOffset: Option[Offset] = synchronized { + if (counter == 0) None else Some(LongOffset(counter)) } override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { From f34fc8a11c0e2d7bf8b36446d3fbfcdc46e7a0ca Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Sep 2017 09:16:50 -0700 Subject: [PATCH 11/14] address --- ...> EnsureStatefulOpPartitioningSuite.scala} | 47 +++++++++++-------- 1 file 
changed, 28 insertions(+), 19 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{IncrementalExecutionRulesSuite.scala => EnsureStatefulOpPartitioningSuite.scala} (71%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala similarity index 71% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala index 9af44fac91fb..4cc125a72327 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/IncrementalExecutionRulesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} import org.apache.spark.sql.test.SharedSQLContext -class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext { +class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { import testImplicits._ super.beforeAll() @@ -39,39 +39,46 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext testEnsureStatefulOpPartitioning( "ClusteredDistribution generates Exchange with HashPartitioning", baseDf.queryExecution.sparkPlan, - keys => ClusteredDistribution(keys), - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), expectShuffle = true) testEnsureStatefulOpPartitioning( "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning", baseDf.coalesce(1).queryExecution.sparkPlan, - keys => ClusteredDistribution(keys), - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), expectShuffle = true) testEnsureStatefulOpPartitioning( "AllTuples generates Exchange with SinglePartition", baseDf.queryExecution.sparkPlan, - keys => AllTuples, - keys => SinglePartition, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, expectShuffle = true) testEnsureStatefulOpPartitioning( "AllTuples with coalesce(1) doesn't need Exchange", baseDf.coalesce(1).queryExecution.sparkPlan, - keys => AllTuples, - keys => SinglePartition, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, expectShuffle = false) + /** + * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan + * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to + * ensure the expected partitioning. 
+ */ private def testEnsureStatefulOpPartitioning( testName: String, inputPlan: SparkPlan, requiredDistribution: Seq[Attribute] => Distribution, expectedPartitioning: Seq[Attribute] => Partitioning, expectShuffle: Boolean): Unit = { - test("EnsureStatefulOpPartitioning - " + testName) { - val operator = TestOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) + test(testName) { + val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) val executed = executePlan(operator, OutputMode.Complete()) if (expectShuffle) { val exchange = executed.children.find(_.isInstanceOf[Exchange]) @@ -88,6 +95,7 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext } } + /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */ private def executePlan( p: SparkPlan, outputMode: OutputMode = OutputMode.Append()): SparkPlan = { @@ -111,13 +119,14 @@ class IncrementalExecutionRulesSuite extends SparkPlanTest with SharedSQLContext } execution.executedPlan } -} -case class TestOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None + /** Used to emulate a [[StatefulOperator]] with the given requiredDistribution. */ + case class TestStatefulOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None + } } From d0e909433020e36ff42700e815f718f49282016d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Sep 2017 14:17:22 -0700 Subject: [PATCH 12/14] move things around --- .../streaming/StreamingAggregationSuite.scala | 73 +++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 642149d0ae06..995cea3b37d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.streaming import java.util.{Locale, TimeZone} -import java.util.concurrent.CountDownLatch import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll @@ -430,14 +429,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ - case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - source.addBlocks(data: _*) - (source, LongOffset(source.counter)) - } - } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { @@ -528,38 +519,46 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } } -} -/** - * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will. - */ -class BlockRDDBackedSource(spark: SparkSession) extends Source { - var counter = 0L - private val blockMgr = SparkEnv.get.blockManager - private var blocks: Seq[BlockId] = Seq.empty - - def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { - dataBlocks.foreach { data => - val id = TestBlockId(counter.toString) - blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) - blocks ++= id :: Nil - counter += 1 + /** Add blocks of data to the `BlockRDDBackedSource`. */ + case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + source.addBlocks(data: _*) + (source, LongOffset(source.counter)) } - counter += 1 } - override def getOffset: Option[Offset] = synchronized { - if (counter == 0) None else Some(LongOffset(counter)) - } + /** + * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will. + */ + class BlockRDDBackedSource(spark: SparkSession) extends Source { + var counter = 0L + private val blockMgr = SparkEnv.get.blockManager + private var blocks: Seq[BlockId] = Seq.empty + + def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { + dataBlocks.foreach { data => + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil + counter += 1 + } + counter += 1 + } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) - .map(i => InternalRow(i)) // we don't really care about the values in this test - blocks = Seq.empty - spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() - } - override def schema: StructType = MockSourceProvider.fakeSchema - override def stop(): Unit = { - blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + override def getOffset: Option[Offset] = synchronized { + if (counter == 0) None else Some(LongOffset(counter)) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) + .map(i => InternalRow(i)) // we don't really care about the values in this test + blocks = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = { + blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + } } } From 8a6eafef056b2a64ee0be07ce886ad69dc295537 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Sep 2017 14:19:06 -0700 Subject: [PATCH 13/14] i think I found it --- .../spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala index 4cc125a72327..9b69109cbd4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -120,7 +120,7 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont execution.executedPlan } - /** Used to emulate a [[StatefulOperator]] with the given requiredDistribution. */ + /** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ case class TestStatefulOperator( child: SparkPlan, requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { From 4eb7f4f6df3f2d5ae831bf15715651598e52c3e6 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Sep 2017 16:10:52 -0700 Subject: [PATCH 14/14] Update EnsureStatefulOpPartitioningSuite.scala --- .../EnsureStatefulOpPartitioningSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala index 9b69109cbd4d..66c0263e872b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -119,14 +119,14 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont } execution.executedPlan } +} - /** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ - case class TestStatefulOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None - } +/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ +case class TestStatefulOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None }
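For reference, stripped of the testStream harness, the query shape these suites guard is simply a streaming aggregation placed downstream of coalesce(1). The sketch below is not part of the patch: it uses MemoryStream, which always produces non-empty local batches and therefore cannot reproduce the zero-partition BlockRDD case that BlockRDDBackedSource was written for, and the object name, application name, and sink table name are made up for illustration.

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

object Spark21977QueryShape {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("spark-21977-query-shape") // hypothetical app name
      .getOrCreate()
    import spark.implicits._
    implicit val sqlContext: SQLContext = spark.sqlContext

    val input = MemoryStream[Int]
    // Global (no grouping key) count over a single coalesced partition,
    // mirroring the aggregation built in the SPARK-21977 tests.
    val counts = input.toDS().coalesce(1).groupBy().count()

    val query = counts.writeStream
      .format("memory")
      .queryName("agg_shape") // hypothetical in-memory sink table
      .outputMode("complete")
      .start()

    input.addData(1, 2, 3)
    query.processAllAvailable()
    spark.table("agg_shape").show() // expect a single row with count = 3

    query.stop()
    spark.stop()
  }
}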