Commit be39125

committed: address
1 parent 6cc8c46 commit be39125

4 files changed: 64 additions, 78 deletions

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
Lines changed: 0 additions & 4 deletions

@@ -231,10 +231,6 @@ object SparkPlanTest {
     }
   }
 
-  /**
-   *
-   */
-
   /**
    * Runs the plan
    * @param outputPlan SparkPlan to be executed

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
Lines changed: 4 additions & 6 deletions

@@ -17,25 +17,23 @@
 
 package org.apache.spark.sql.streaming
 
-import java.io.File
 import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap
 
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder}
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
 import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.{RDDScanExec, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
-import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream, StreamingQueryWrapper}
+import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
-import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock}
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /** Class to check custom state types */

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
Lines changed: 3 additions & 2 deletions

@@ -351,6 +351,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     var manualClockExpectedTime = -1L
+    val defaultCheckpointLocation =
+      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     try {
      startedTest.foreach { action =>
        logInfo(s"Processing test stream action: $action")
@@ -363,8 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
          if (triggerClock.isInstanceOf[StreamManualClock]) {
            manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
          }
-          val metadataRoot = Option(checkpointLocation).getOrElse(
-            Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath)
+          val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
 
          additionalConfs.foreach(pair => {
            val value =
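Side note (not part of the commit): the effect of the StreamTest change above is that every StartStream inside one testStream() call that does not pass its own checkpointLocation now shares a single pre-created checkpoint directory, instead of getting a fresh temp dir on each (re)start, so a stream stopped and restarted mid-test recovers from the same metadata. A minimal standalone sketch of that resolution logic, using java.nio in place of Spark's private Utils.createTempDir helper; the object and method names are illustrative, not from the patch:

import java.nio.file.Files

object CheckpointResolutionSketch {
  // Created once up front and reused by every restart that does not supply
  // an explicit checkpoint location (mirrors defaultCheckpointLocation above).
  val defaultCheckpointLocation: String =
    Files.createTempDirectory("streaming.metadata").toFile.getCanonicalPath

  // An explicit location wins; null falls back to the shared default,
  // mirroring Option(checkpointLocation).getOrElse(defaultCheckpointLocation).
  def resolveMetadataRoot(checkpointLocation: String): String =
    Option(checkpointLocation).getOrElse(defaultCheckpointLocation)

  def main(args: Array[String]): Unit = {
    println(resolveMetadataRoot(null))                 // shared default directory
    println(resolveMetadataRoot("/tmp/my-checkpoint")) // explicit location passes through
  }
}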

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
Lines changed: 57 additions & 66 deletions

@@ -24,22 +24,20 @@ import org.scalatest.Assertions
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SparkEnv, SparkException}
-import org.apache.spark.rdd.{BlockRDD, RDD}
+import org.apache.spark.rdd.BlockRDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
 import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock}
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 
 object FailureSingleton {
@@ -435,64 +433,68 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {
-      if (data.nonEmpty) {
-        data.foreach(source.addData)
-      } else {
-        // we would like to create empty blockRDD's so add an empty block here.
-        source.addData()
-      }
-      source.releaseLock()
+      source.addBlocks(data: _*)
       (source, LongOffset(source.counter))
     }
   }
 
   test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
-      withTempDir { tempDir =>
-        val aggregated: Dataset[Long] =
-          spark.readStream
-            .format((new MockSourceProvider).getClass.getCanonicalName)
-            .load()
-            .coalesce(1)
-            .groupBy()
-            .count()
-            .as[Long]
-
-        testStream(aggregated, Complete())(
-          AddBlockData(inputSource, Seq(1)),
-          CheckLastBatch(1),
-          AssertOnQuery("Verify no shuffling") { se =>
-            checkAggregationChain(se, expectShuffling = false, 1)
-          },
-          AddBlockData(inputSource), // create an empty trigger
-          CheckLastBatch(1),
-          AssertOnQuery("Verify addition of exchange operator") { se =>
-            checkAggregationChain(se, expectShuffling = true, 1)
-          },
-          AddBlockData(inputSource, Seq(2, 3)),
-          CheckLastBatch(3),
-          AddBlockData(inputSource),
-          CheckLastBatch(3),
-          StopStream
-        )
-      }
+      // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default
+      // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we
+      // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1
+      // partition. Which means that if we have an input RDD with 0 partitions, nothing gets
+      // executed. Therefore the StateStore's don't save any delta files for a given trigger. This
+      // then leads to `FileNotFoundException`s in the subsequent batch.
+      // This isn't the only problem though. Once we introduce a shuffle before
+      // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing
+      // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If
+      // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in
+      // `StateStoreSaveExec` losing all previous state.
+      val aggregated: Dataset[Long] =
+        spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName)
+          .load().coalesce(1).groupBy().count().as[Long]
+
+      testStream(aggregated, Complete())(
+        AddBlockData(inputSource, Seq(1)),
+        CheckLastBatch(1),
+        AssertOnQuery("Verify no shuffling") { se =>
+          checkAggregationChain(se, expectShuffling = false, 1)
+        },
+        AddBlockData(inputSource), // create an empty trigger
+        CheckLastBatch(1),
+        AssertOnQuery("Verify addition of exchange operator") { se =>
+          checkAggregationChain(se, expectShuffling = true, 1)
+        },
+        AddBlockData(inputSource, Seq(2, 3)),
+        CheckLastBatch(3),
+        AddBlockData(inputSource),
+        CheckLastBatch(3),
+        StopStream
+      )
     }
   }
 
-  test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") {
+  test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " +
+    "has non-empty grouping keys") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
       withTempDir { tempDir =>
 
+        // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default
+        // satisfies the required distributions of all aggregations. However, when we have
+        // non-empty grouping keys, in streaming, we must repartition to
+        // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process
+        // all keys. This may be fine, however, if the user removes the coalesce(1) or changes to
+        // a `coalesce(2)` for example, then the default behavior is to shuffle to
+        // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStore's
+        // except 1 will be missing their previous delta files, which causes the stream to fail
+        // with FileNotFoundException.
        def createDf(partitions: Int): Dataset[(Long, Long)] = {
          spark.readStream
            .format((new MockSourceProvider).getClass.getCanonicalName)
-            .load()
-            .coalesce(partitions)
-            .groupBy('a % 1) // just to give it a fake key
-            .count()
-            .as[(Long, Long)]
+            .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)]
        }
 
        testStream(createDf(1), Complete())(
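An aside, not part of the commit: the long comment in the first test above hinges on the fact that coalesce(n) only guarantees at most n partitions, so an input RDD with zero partitions keeps zero partitions and no task (and no StateStore delta file) is produced for that trigger. A small standalone sketch that should demonstrate this, assuming a local SparkSession; the object name is made up:

import org.apache.spark.sql.SparkSession

object CoalesceZeroPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("coalesce-zero-partitions-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    val empty = sc.emptyRDD[Int]           // an RDD with 0 partitions
    val nonEmpty = sc.parallelize(1 to 4)  // an RDD with > 0 partitions

    // coalesce(1) is an upper bound, not a lower bound: the empty RDD is
    // expected to stay at 0 partitions, so nothing would be scheduled for it.
    println(s"empty.coalesce(1)    -> ${empty.coalesce(1).getNumPartitions} partition(s)")
    println(s"nonEmpty.coalesce(1) -> ${nonEmpty.coalesce(1).getNumPartitions} partition(s)")

    spark.stop()
  }
}

The expected 0 on the first line is exactly the situation the repartition-to-1 fix in these tests guards against.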
@@ -536,29 +538,18 @@ class BlockRDDBackedSource(spark: SparkSession) extends Source {
   private val blockMgr = SparkEnv.get.blockManager
   private var blocks: Seq[BlockId] = Seq.empty
 
-  private var streamLock: CountDownLatch = new CountDownLatch(1)
-
-  def addData(data: Int*): Unit = {
-    if (streamLock.getCount == 0) {
-      streamLock = new CountDownLatch(1)
-    }
-    synchronized {
-      if (data.nonEmpty) {
-        val id = TestBlockId(counter.toString)
-        blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY)
-        blocks ++= id :: Nil
-      }
+  def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized {
+    dataBlocks.foreach { data =>
+      val id = TestBlockId(counter.toString)
+      blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY)
+      blocks ++= id :: Nil
      counter += 1
    }
+    counter += 1
  }
 
-  def releaseLock(): Unit = streamLock.countDown()
-
-  override def getOffset: Option[Offset] = {
-    streamLock.await()
-    synchronized {
-      if (counter == 0) None else Some(LongOffset(counter))
-    }
+  override def getOffset: Option[Offset] = synchronized {
+    if (counter == 0) None else Some(LongOffset(counter))
   }
 
   override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
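A brief note on the refactor above (my reading, not stated in the commit): addBlocks replaces the CountDownLatch handshake with plain synchronization, and the trailing counter += 1 advances the offset even when addBlocks() is called with no arguments, which is what lets AddBlockData(inputSource) produce an empty trigger. A minimal, Spark-free sketch of just that counter behavior; the class name is made up and the BlockManager interaction is reduced to a comment:

// Standalone sketch: only the offset/counter bookkeeping of the refactored
// source, with the BlockManager calls stubbed out.
class CounterOnlySource {
  private var counter: Long = 0L

  def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized {
    dataBlocks.foreach { data =>
      // the real source stores `data` as one block in the BlockManager here
      counter += 1
    }
    counter += 1 // advance once more, so addBlocks() with no args still moves the offset
  }

  def getOffset: Option[Long] = synchronized {
    if (counter == 0) None else Some(counter)
  }
}

object CounterOnlySourceDemo {
  def main(args: Array[String]): Unit = {
    val src = new CounterOnlySource
    println(src.getOffset)           // None: no data yet, so no trigger
    src.addBlocks(Seq(1, 2), Seq(3))
    println(src.getOffset)           // Some(3): two blocks plus the extra increment
    src.addBlocks()                  // an "empty trigger": no blocks, offset still advances
    println(src.getOffset)           // Some(4)
  }
}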
