Skip to content

Commit 4ff9f7e

Browse files
committed
[SPARK-23197][STREAMING] Fix ReceiverSuite."receiver_life_cycle" to not rely on timing
1 parent 2c775f4 commit 4ff9f7e

File tree

1 file changed

+61
-28
lines changed

1 file changed

+61
-28
lines changed

streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
7373
executorStarted.acquire()
7474

7575
// Verify that receiver was started
76-
assert(receiver.onStartCalled)
76+
assert(receiver.callsRecorder.calls === Seq("onStart"))
7777
assert(executor.isReceiverStarted)
7878
assert(receiver.isStarted)
7979
assert(!receiver.isStopped())
@@ -106,19 +106,22 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
106106
assert(executor.errors.head.eq(exception))
107107

108108
// Verify restarting actually stops and starts the receiver
109-
receiver.restart("restarting", null, 600)
110-
eventually(timeout(300.milliseconds), interval(10.milliseconds)) {
111-
// receiver will be stopped async
112-
assert(receiver.isStopped)
113-
assert(receiver.onStopCalled)
114-
}
115-
eventually(timeout(1.second), interval(10.milliseconds)) {
116-
// receiver will be started async
117-
assert(receiver.onStartCalled)
118-
assert(executor.isReceiverStarted)
109+
executor.callsRecorder.reset()
110+
receiver.callsRecorder.reset()
111+
receiver.restart("restarting", null, 100)
112+
eventually(timeout(10.seconds), interval(10.milliseconds)) {
113+
// below verification ensures for now receiver is already restarted
119114
assert(receiver.isStarted)
120115
assert(!receiver.isStopped)
121116
assert(receiver.receiving)
117+
118+
// both receiver supervisor and receiver should be stopped first, and started
119+
assert(executor.callsRecorder.calls === Seq("onReceiverStop", "onReceiverStart"))
120+
assert(receiver.callsRecorder.calls === Seq("onStop", "onStart"))
121+
122+
// check whether the delay between stop and start is respected
123+
assert(executor.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100)
124+
assert(receiver.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100)
122125
}
123126

124127
// Verify that stopping actually stops the thread
@@ -290,6 +293,9 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
290293
val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]]
291294
val errors = new ArrayBuffer[Throwable]
292295

296+
// tracks calls of "onReceiverStart", "onReceiverStop"
297+
val callsRecorder = new MethodsCallRecorder()
298+
293299
/** Check if all data structures are clean */
294300
def isAllEmpty: Boolean = {
295301
singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
@@ -325,7 +331,15 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
325331
errors += throwable
326332
}
327333

328-
override protected def onReceiverStart(): Boolean = true
334+
override protected def onReceiverStart(): Boolean = {
335+
callsRecorder.record()
336+
true
337+
}
338+
339+
override protected def onReceiverStop(message: String, error: Option[Throwable]): Unit = {
340+
callsRecorder.record()
341+
super.onReceiverStop(message, error)
342+
}
329343

330344
override def createBlockGenerator(
331345
blockGeneratorListener: BlockGeneratorListener): BlockGenerator = {
@@ -363,36 +377,55 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
363377
class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
364378
@volatile var otherThread: Thread = null
365379
@volatile var receiving = false
366-
@volatile var onStartCalled = false
367-
@volatile var onStopCalled = false
380+
381+
// tracks calls of "onStart", "onStop"
382+
val callsRecorder = new MethodsCallRecorder()
368383

369384
def onStart() {
370385
otherThread = new Thread() {
371386
override def run() {
372387
receiving = true
373-
var count = 0
374-
while(!isStopped()) {
375-
if (sendData) {
376-
store(count)
377-
count += 1
388+
try {
389+
var count = 0
390+
while(!isStopped()) {
391+
if (sendData) {
392+
store(count)
393+
count += 1
394+
}
395+
Thread.sleep(10)
378396
}
379-
Thread.sleep(10)
397+
} finally {
398+
receiving = false
380399
}
381400
}
382401
}
383-
onStartCalled = true
402+
callsRecorder.record()
384403
otherThread.start()
385404
}
386405

387406
def onStop() {
388-
onStopCalled = true
407+
callsRecorder.record()
389408
otherThread.join()
390409
}
391-
392-
def reset() {
393-
receiving = false
394-
onStartCalled = false
395-
onStopCalled = false
396-
}
397410
}
398411

412+
class MethodsCallRecorder {
413+
// tracks calling methods as (timestamp, methodName)
414+
private val records = new ArrayBuffer[(Long, String)]
415+
416+
def record(): Unit = records.append((System.currentTimeMillis(), callerMethodName))
417+
418+
def reset(): Unit = records.clear()
419+
420+
def callsWithTimestamp: scala.collection.immutable.Seq[(Long, String)] = records.toList
421+
422+
def calls: scala.collection.immutable.Seq[String] = records.map(_._2).toList
423+
424+
def timestamps: scala.collection.immutable.Seq[Long] = records.map(_._1).toList
425+
426+
private def callerMethodName: String = {
427+
val stackTrace = new Throwable().getStackTrace
428+
// it should return method name of two levels deeper
429+
stackTrace(2).getMethodName
430+
}
431+
}

0 commit comments

Comments
 (0)