Skip to content

Commit dda1403

Browse files
committed
Add StreamingTestWaiter class.
1 parent 3c3efc3 commit dda1403

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ import java.io.{ObjectInputStream, IOException}
2121

2222
import scala.collection.mutable.ArrayBuffer
2323
import scala.collection.mutable.SynchronizedBuffer
24+
import scala.concurrent.TimeoutException
2425
import scala.reflect.ClassTag
2526

2627
import org.scalatest.{BeforeAndAfter, FunSuite}
2728

2829
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
30+
import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener}
2931
import org.apache.spark.streaming.util.ManualClock
3032
import org.apache.spark.{SparkConf, Logging}
3133
import org.apache.spark.rdd.RDD
@@ -103,6 +105,77 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
103105
def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
104106
}
105107

108+
/**
 * A test helper that blocks the calling thread until certain streaming events have occurred,
 * such as the start or completion of a given number of batches. This is much less brittle
 * than sleeping for a fixed amount of wall-clock time.
 *
 * Internally this is implemented with a StreamingListener: constructing an instance
 * registers a listener on the given StreamingContext, and that listener counts
 * batch-started and batch-completed events and wakes up any waiting threads.
 *
 * @param ssc the StreamingContext on which the counting listener is registered
 */
class StreamingTestWaiter(ssc: StreamingContext) {

  // All access to this state should be guarded by `StreamingTestWaiter.this.synchronized`
  private var numCompletedBatches = 0
  private var numStartedBatches = 0

  private val listener = new StreamingListener {
    override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
      StreamingTestWaiter.this.synchronized {
        numStartedBatches += 1
        // Wake every waiter; each one re-checks its own condition.
        StreamingTestWaiter.this.notifyAll()
      }
    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
      StreamingTestWaiter.this.synchronized {
        numCompletedBatches += 1
        // Wake every waiter; each one re-checks its own condition.
        StreamingTestWaiter.this.notifyAll()
      }
  }
  // Registered only after the counters above have been initialized.
  ssc.addStreamingListener(listener)

  /** Returns the number of batch-completed events observed so far. */
  def getNumCompletedBatches: Int = this.synchronized {
    numCompletedBatches
  }

  /** Returns the number of batch-started events observed so far. */
  def getNumStartedBatches: Int = this.synchronized {
    numStartedBatches
  }

  /**
   * Block until the number of completed batches reaches the given threshold.
   *
   * @param targetNumBatches total number of completed batches to wait for
   * @param timeout maximum time to block
   * @throws TimeoutException if the threshold has not been reached once `timeout` has elapsed
   */
  def waitForTotalBatchesCompleted(
      targetNumBatches: Int,
      timeout: Duration): Unit = {
    waitForCount("completed", targetNumBatches, timeout)(numCompletedBatches)
  }

  /**
   * Block until the number of started batches reaches the given threshold.
   *
   * @param targetNumBatches total number of started batches to wait for
   * @param timeout maximum time to block
   * @throws TimeoutException if the threshold has not been reached once `timeout` has elapsed
   */
  def waitForTotalBatchesStarted(
      targetNumBatches: Int,
      timeout: Duration): Unit = {
    waitForCount("started", targetNumBatches, timeout)(numStartedBatches)
  }

  /**
   * Shared wait loop: blocks on this object's monitor until `currentCount` reaches
   * `targetNumBatches` or `timeout` has elapsed, throwing TimeoutException in the latter case.
   *
   * @param eventName word used in the error message ("completed" or "started")
   * @param currentCount by-name read of the counter; evaluated while holding the monitor
   */
  private def waitForCount(
      eventName: String,
      targetNumBatches: Int,
      timeout: Duration)(currentCount: => Int): Unit = this.synchronized {
    val timeoutMs = timeout.milliseconds
    // nanoTime is monotonic, so the deadline is unaffected by wall-clock adjustments.
    val startNanos = System.nanoTime()
    def remainingMs: Long = timeoutMs - (System.nanoTime() - startNanos) / 1000000L

    var remaining = remainingMs
    while (currentCount < targetNumBatches && remaining > 0) {
      // Wait only for the time that is left: notifications for earlier batches and spurious
      // wakeups must not restart a full-length wait. `remaining` is strictly positive here,
      // so this never degenerates into wait(0), which would block indefinitely.
      this.wait(remaining)
      remaining = remainingMs
    }
    if (currentCount < targetNumBatches) {
      throw new TimeoutException(s"Waited for $targetNumBatches $eventName batches, but only" +
        s" $currentCount have $eventName after $timeout")
    }
  }
}
178+
106179
/**
107180
* This is the base trait for Spark Streaming test suites. This provides basic functionality
108181
* to run a user-defined set of inputs on user-defined stream operations, and verify the output.

0 commit comments

Comments
 (0)