@@ -21,11 +21,13 @@ import java.io.{ObjectInputStream, IOException}
2121
2222import scala .collection .mutable .ArrayBuffer
2323import scala .collection .mutable .SynchronizedBuffer
24+ import scala .concurrent .TimeoutException
2425import scala .reflect .ClassTag
2526
2627import org .scalatest .{BeforeAndAfter , FunSuite }
2728
2829import org .apache .spark .streaming .dstream .{DStream , InputDStream , ForEachDStream }
30+ import org .apache .spark .streaming .scheduler .{StreamingListenerBatchStarted , StreamingListenerBatchCompleted , StreamingListener }
2931import org .apache .spark .streaming .util .ManualClock
3032import org .apache .spark .{SparkConf , Logging }
3133import org .apache .spark .rdd .RDD
@@ -103,6 +105,77 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
103105 def toTestOutputStream = new TestOutputStream [T ](this .parent, this .output.map(_.flatten))
104106}
105107
/**
 * This is an interface that can be used to block until certain events occur, such as
 * the start/completion of batches. This is much less brittle than waiting on wall-clock time.
 * Internally, this is implemented using a StreamingListener. Constructing a new instance of this
 * class automatically registers a StreamingListener on the given StreamingContext.
 */
class StreamingTestWaiter(ssc: StreamingContext) {

  // All access to this state should be guarded by `StreamingTestWaiter.this.synchronized`
  private var numCompletedBatches = 0
  private var numStartedBatches = 0

  // Bump the counters and wake every thread blocked in a waitForTotalBatches* call so it can
  // re-check its predicate.
  private val listener = new StreamingListener {
    override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
      StreamingTestWaiter.this.synchronized {
        numStartedBatches += 1
        StreamingTestWaiter.this.notifyAll()
      }
    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
      StreamingTestWaiter.this.synchronized {
        numCompletedBatches += 1
        StreamingTestWaiter.this.notifyAll()
      }
  }
  ssc.addStreamingListener(listener)

  /** Number of batches that have completed so far. */
  def getNumCompletedBatches: Int = this.synchronized {
    numCompletedBatches
  }

  /** Number of batches that have started so far. */
  def getNumStartedBatches: Int = this.synchronized {
    numStartedBatches
  }

  /**
   * Block until the number of completed batches reaches the given threshold.
   *
   * @throws TimeoutException if the threshold is not reached within `timeout`
   */
  def waitForTotalBatchesCompleted(
      targetNumBatches: Int,
      timeout: Duration): Unit = {
    waitForBatchCounter(targetNumBatches, timeout, "completed")(numCompletedBatches)
  }

  /**
   * Block until the number of started batches reaches the given threshold.
   *
   * @throws TimeoutException if the threshold is not reached within `timeout`
   */
  def waitForTotalBatchesStarted(
      targetNumBatches: Int,
      timeout: Duration): Unit = {
    waitForBatchCounter(targetNumBatches, timeout, "started")(numStartedBatches)
  }

  /**
   * Shared wait loop for the two public wait methods: block until `currentCount` (re-read under
   * this object's lock after each notification) reaches `targetNumBatches`, or throw a
   * TimeoutException once `timeout` has elapsed.
   *
   * @param kind used only in the failure message ("started" or "completed")
   * @param currentCount by-name counter read; must be a field guarded by this object's lock
   */
  private def waitForBatchCounter(
      targetNumBatches: Int,
      timeout: Duration,
      kind: String)(currentCount: => Int): Unit = this.synchronized {
    val deadline = System.currentTimeMillis() + timeout.milliseconds
    // Wait only for the *remaining* time on each iteration; waiting for the full timeout after
    // every wakeup could block for far longer than the requested timeout. Note that
    // Object.wait(0) would wait forever, hence the strict `remaining > 0` guard.
    var remaining = timeout.milliseconds
    while (currentCount < targetNumBatches && remaining > 0) {
      this.wait(remaining)
      remaining = deadline - System.currentTimeMillis()
    }
    if (currentCount < targetNumBatches) {
      throw new TimeoutException(s"Waited for $targetNumBatches $kind batches, but only " +
        s"$currentCount have $kind after $timeout")
    }
  }
}
178+
106179/**
107180 * This is the base trait for Spark Streaming testsuites. This provides basic functionality
108181 * to run user-defined set of input on user-defined stream operations, and verify the output.
0 commit comments