Skip to content

Commit 15b48ee

Browse files
committed
Replace several TestWaiter methods w/ ScalaTest eventually.
1 parent fffc51c commit 15b48ee

File tree

3 files changed

+27
-42
lines changed

3 files changed

+27
-42
lines changed

streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
2929
import org.apache.hadoop.io.{IntWritable, Text}
3030
import org.apache.hadoop.mapred.TextOutputFormat
3131
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
32+
import org.scalatest.concurrent.Eventually._
3233

3334
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
3435
import org.apache.spark.streaming.util.ManualClock
@@ -360,19 +361,25 @@ class CheckpointSuite extends TestSuiteBase {
360361
outputStream.register()
361362
ssc.start()
362363

363-
clock.addToTime(batchDuration.milliseconds)
364+
// Advance half a batch so that the first file is created after the StreamingContext starts
365+
clock.addToTime(batchDuration.milliseconds / 2)
364366
// Create files and advance manual clock to process them
365367
for (i <- Seq(1, 2, 3)) {
366368
writeFile(i, clock)
367369
clock.addToTime(batchDuration.milliseconds)
368370
if (i != 3) {
369371
// Since we want to shut down while the 3rd batch is processing
370-
waiter.waitForTotalBatchesCompleted(i, batchDuration * 5)
372+
eventually(timeout(batchDuration * 5)) {
373+
assert(waiter.getNumCompletedBatches === i)
374+
}
371375
}
372376
}
373377
clock.addToTime(batchDuration.milliseconds)
374-
waiter.waitForTotalBatchesStarted(3, batchDuration * 5)
375378
Thread.sleep(1000) // To wait for execution to actually begin
379+
eventually(timeout(batchDuration * 5)) {
380+
assert(waiter.getNumStartedBatches === 3)
381+
}
382+
assert(waiter.getNumCompletedBatches === 2)
376383
logInfo("Output after first start = " + outputStream.output.mkString("[", ", ", "]"))
377384
assert(outputStream.output.size > 0, "No files processed before restart")
378385
ssc.stop()
@@ -410,7 +417,9 @@ class CheckpointSuite extends TestSuiteBase {
410417
for ((i, index) <- Seq(7, 8, 9).zipWithIndex) {
411418
writeFile(i, clock)
412419
clock.addToTime(batchDuration.milliseconds)
413-
waiter.waitForTotalBatchesCompleted(index + 1, batchDuration * 5)
420+
eventually(timeout(batchDuration * 5)) {
421+
assert(waiter.getNumCompletedBatches === index + 1)
422+
}
414423
}
415424
clock.addToTime(batchDuration.milliseconds)
416425
logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]"))

streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import scala.language.postfixOps
3232

3333
import com.google.common.io.Files
3434
import org.scalatest.BeforeAndAfter
35+
import org.scalatest.concurrent.Eventually._
3536

3637
import org.apache.spark.Logging
3738
import org.apache.spark.storage.StorageLevel
@@ -262,7 +263,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
262263
assert(file.setLastModified(clock.currentTime()))
263264
assert(file.lastModified === clock.currentTime)
264265
logInfo("Created file " + file)
265-
waiter.waitForTotalBatchesCompleted(i, timeout = batchDuration * 5)
266+
eventually(timeout(batchDuration * 5)) {
267+
assert(waiter.getNumCompletedBatches === i)
268+
}
266269
}
267270

268271
// Verify that all the files have been read

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ import java.io.{ObjectInputStream, IOException}
2121

2222
import scala.collection.mutable.ArrayBuffer
2323
import scala.collection.mutable.SynchronizedBuffer
24-
import scala.concurrent.TimeoutException
24+
import scala.language.implicitConversions
2525
import scala.reflect.ClassTag
2626

2727
import org.scalatest.{BeforeAndAfter, FunSuite}
28+
import org.scalatest.time.{Span, Milliseconds => ScalaTestMilliseconds}
2829

2930
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
3031
import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener}
@@ -138,42 +139,6 @@ class StreamingTestWaiter(ssc: StreamingContext) {
138139
def getNumStartedBatches: Int = this.synchronized {
139140
numStartedBatches
140141
}
141-
142-
/**
143-
* Block until the number of completed batches reaches the given threshold.
144-
*/
145-
def waitForTotalBatchesCompleted(
146-
targetNumBatches: Int,
147-
timeout: Duration): Unit = this.synchronized {
148-
val startTime = System.currentTimeMillis()
149-
def successful = getNumCompletedBatches >= targetNumBatches
150-
def timedOut = (System.currentTimeMillis() - startTime) >= timeout.milliseconds
151-
while (!timedOut && !successful) {
152-
this.wait(timeout.milliseconds)
153-
}
154-
if (!successful && timedOut) {
155-
throw new TimeoutException(s"Waited for $targetNumBatches completed batches, but only" +
156-
s" $numCompletedBatches have completed after $timeout")
157-
}
158-
}
159-
160-
/**
161-
* Block until the number of started batches reaches the given threshold.
162-
*/
163-
def waitForTotalBatchesStarted(
164-
targetNumBatches: Int,
165-
timeout: Duration): Unit = this.synchronized {
166-
val startTime = System.currentTimeMillis()
167-
def successful = getNumStartedBatches >= targetNumBatches
168-
def timedOut = (System.currentTimeMillis() - startTime) >= timeout.milliseconds
169-
while (!timedOut && !successful) {
170-
this.wait(timeout.milliseconds)
171-
}
172-
if (!successful && timedOut) {
173-
throw new TimeoutException(s"Waited for $targetNumBatches started batches, but only" +
174-
s" $numStartedBatches have started after $timeout")
175-
}
176-
}
177142
}
178143

179144
/**
@@ -236,6 +201,14 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
236201
before(beforeFunction)
237202
after(afterFunction)
238203

204+
/**
205+
* Implicit conversion which allows streaming Durations to be used with ScalaTest methods,
206+
* such as `eventually`.
207+
*/
208+
implicit def streamingDurationToScalatestSpan(duration: Duration): Span = {
209+
Span(duration.milliseconds, ScalaTestMilliseconds)
210+
}
211+
239212
/**
240213
* Run a block of code with the given StreamingContext and automatically
241214
* stop the context when the block completes or when an exception is thrown.

0 commit comments

Comments
 (0)