
Commit ad0056b

WIP towards improved CheckpointSuite sleep logic.
1 parent 5c31b8a commit ad0056b

2 files changed: +32 -14

streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

Lines changed: 13 additions & 9 deletions
@@ -21,6 +21,8 @@ import java.io.File
 import java.nio.charset.Charset
 
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.postfixOps
 import scala.reflect.ClassTag
 
 import com.google.common.io.Files
@@ -88,7 +90,7 @@ class CheckpointSuite extends TestSuiteBase {
     // Run till a time such that at least one RDD in the stream should have been checkpointed,
     // then check whether some RDD has been checkpointed or not
     ssc.start()
-    advanceTimeWithRealDelay(ssc, firstNumBatches)
+    advanceTimeWithRealDelay(ssc, firstNumBatches.toInt)
     logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData)
     assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
       "No checkpointed RDDs in state stream before first failure")
@@ -102,7 +104,7 @@ class CheckpointSuite extends TestSuiteBase {
     // Run till a further time such that previous checkpoint files in the stream would be deleted
     // and check whether the earlier checkpoint files are deleted
     val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2))
-    advanceTimeWithRealDelay(ssc, secondNumBatches)
+    advanceTimeWithRealDelay(ssc, secondNumBatches.toInt)
     checkpointFiles.foreach(file =>
       assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
     ssc.stop()
@@ -409,8 +411,7 @@ class CheckpointSuite extends TestSuiteBase {
     ssc.start()
     val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
     ssc.stop()
-    verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
-    Thread.sleep(1000)
+    verifyOutput(output, expectedOutput.take(initialNumBatches), useSet = true)
 
     // Restart and complete the computation from checkpoint file
     logInfo(
@@ -419,10 +420,13 @@ class CheckpointSuite extends TestSuiteBase {
       "\n-------------------------------------------\n"
     )
     ssc = new StreamingContext(checkpointDir)
+    val waiter = new StreamingTestWaiter(ssc)
     ssc.start()
+    // Wait for the last batch before restart to be re-processed:
+    waiter.waitForTotalBatchesCompleted(1, timeout = 10 seconds)
    val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
     // the first element will be re-processed data of the last batch before restart
-    verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
+    verifyOutput(outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), useSet = true)
     ssc.stop()
     ssc = null
   }
@@ -431,15 +435,15 @@ class CheckpointSuite extends TestSuiteBase {
   * Advances the manual clock on the streaming scheduler by given number of batches.
   * It also waits for the expected amount of time for each batch.
   */
-  def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
+  def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Int): Seq[Seq[V]] = {
     val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    val waiter = new StreamingTestWaiter(ssc)
     logInfo("Manual clock before advancing = " + clock.time)
-    for (i <- 1 to numBatches.toInt) {
+    for (i <- 1 to numBatches) {
       clock.addToTime(batchDuration.milliseconds)
-      Thread.sleep(batchDuration.milliseconds)
+      waiter.waitForTotalBatchesCompleted(i, timeout = 10 seconds)
     }
     logInfo("Manual clock after advancing = " + clock.time)
-    Thread.sleep(batchDuration.milliseconds)
 
     val outputStream = ssc.graph.getOutputStreams.filter { dstream =>
       dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
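
What this file's changes amount to: instead of advancing the manual clock and then sleeping for a wall-clock batch duration (and hoping the batch has actually finished), the test now blocks until a listener reports that the expected total number of batches has completed. Below is a minimal, self-contained sketch of that clock-advance-then-wait pattern; ManualClock and BatchTracker here are simplified stand-ins written for this example, not Spark's actual classes.

import java.util.concurrent.TimeoutException

// Simplified stand-in for a manual clock: time advances only when told to.
class ManualClock {
  private var currentTime = 0L
  def time: Long = synchronized { currentTime }
  def addToTime(delta: Long): Unit = synchronized { currentTime += delta }
}

// Stand-in for the scheduler side: completing a batch wakes up any waiters.
class BatchTracker {
  private var completed = 0
  def batchCompleted(): Unit = synchronized { completed += 1; notifyAll() }
  def waitForTotalBatchesCompleted(target: Int, timeoutMs: Long): Unit = synchronized {
    val deadline = System.nanoTime + timeoutMs * 1000000L
    while (completed < target) {
      val remainingMs = (deadline - System.nanoTime) / 1000000L
      if (remainingMs <= 0) {
        throw new TimeoutException(s"only $completed of $target batches completed")
      }
      wait(remainingMs)
    }
  }
}

object AdvanceTimeExample {
  def main(args: Array[String]): Unit = {
    val clock = new ManualClock
    val tracker = new BatchTracker

    // Fake scheduler thread: completes three batches on its own timer
    // (a real scheduler would react to the clock ticks).
    val scheduler = new Thread(new Runnable {
      def run(): Unit = for (_ <- 1 to 3) { Thread.sleep(50); tracker.batchCompleted() }
    })
    scheduler.start()

    // Test side: advance the clock one batch at a time, then wait on
    // completion instead of sleeping for a fixed interval.
    for (i <- 1 to 3) {
      clock.addToTime(500L)
      tracker.waitForTotalBatchesCompleted(i, timeoutMs = 10000)
      println(s"batch $i completed at manual time " + clock.time)
    }
    scheduler.join()
  }
}

Compared with fixed sleeps, this is both faster (the loop proceeds as soon as a batch finishes) and more robust on slow machines (it keeps waiting up to the timeout rather than failing when a batch overruns one batch duration).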

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 19 additions & 5 deletions
@@ -18,9 +18,11 @@
 package org.apache.spark.streaming
 
 import java.io.{ObjectInputStream, IOException}
+import java.util.concurrent.TimeoutException
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.SynchronizedBuffer
+import scala.concurrent.duration.{Duration => SDuration}
 import scala.reflect.ClassTag
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -124,13 +126,25 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
   }
   ssc.addStreamingListener(listener)
 
+  def getNumCompletedBatches: Int = this.synchronized {
+    numCompletedBatches
+  }
+
   /**
-   * Block until a batch completes.
+   * Block until the number of completed batches reaches the given threshold.
    */
-  def waitForBatchToComplete(): Unit = this.synchronized {
-    val currentBatchesCompleted = numCompletedBatches
-    while (numCompletedBatches < currentBatchesCompleted + 1) {
-      this.wait()
+  def waitForTotalBatchesCompleted(
+      targetNumBatches: Int,
+      timeout: SDuration = SDuration.Inf): Unit = this.synchronized {
+    val startTime = System.nanoTime
+    def timedOut = timeout < SDuration.Inf && (System.nanoTime - startTime) >= timeout.toNanos
+    def successful = getNumCompletedBatches >= targetNumBatches
+    while (!timedOut && !successful) {
+      this.wait(timeout.toMillis)
+    }
+    if (!successful && timedOut) {
+      throw new TimeoutException(s"Waited for $targetNumBatches completed batches, but only" +
+        s" $numCompletedBatches have completed after $timeout")
     }
   }
 }
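
The wait loop above is the standard guarded-wait idiom: the predicate is re-checked after every wait(), since notifyAll() wakes every waiter and the JVM also permits spurious wakeups. Two caveats when reusing this WIP version: with the default SDuration.Inf, timeout.toMillis throws (infinite Durations cannot be converted to milliseconds), so the no-timeout path fails as soon as it actually has to wait; and each loop iteration waits the full timeout again rather than only the time remaining. A variant that handles both, offered as an illustrative sketch rather than the commit's code:

import java.util.concurrent.TimeoutException
import scala.concurrent.duration.Duration

class BatchWaiterSketch {
  private var numCompletedBatches = 0

  // Called from the listener thread whenever a batch finishes.
  def batchCompleted(): Unit = this.synchronized {
    numCompletedBatches += 1
    this.notifyAll()
  }

  def waitForTotalBatchesCompleted(
      targetNumBatches: Int,
      timeout: Duration = Duration.Inf): Unit = this.synchronized {
    val deadline =
      if (timeout.isFinite) Some(System.nanoTime + timeout.toNanos) else None
    while (numCompletedBatches < targetNumBatches) {
      deadline match {
        case Some(d) =>
          // Wait only for the time remaining, not the full timeout again.
          val remainingMs = (d - System.nanoTime) / 1000000L
          if (remainingMs <= 0) {
            throw new TimeoutException(s"Waited for $targetNumBatches completed batches," +
              s" but only $numCompletedBatches have completed after $timeout")
          }
          this.wait(remainingMs)
        case None =>
          this.wait() // no deadline: block until notified
      }
    }
  }
}

Usage mirrors the diff: waitForTotalBatchesCompleted(n) blocks until n batches have completed, while waitForTotalBatchesCompleted(n, 10.seconds) (with scala.concurrent.duration._ imported) throws a TimeoutException if the scheduler stalls.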
