@@ -73,7 +73,7 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
7373 executorStarted.acquire()
7474
7575 // Verify that receiver was started
76- assert(receiver.onStartCalled )
76+ assert(receiver.callsRecorder.calls === Seq ( " onStart " ) )
7777 assert(executor.isReceiverStarted)
7878 assert(receiver.isStarted)
7979 assert(! receiver.isStopped())
@@ -106,19 +106,22 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
106106 assert(executor.errors.head.eq(exception))
107107
108108 // Verify restarting actually stops and starts the receiver
109- receiver.restart(" restarting" , null , 600 )
110- eventually(timeout(300 .milliseconds), interval(10 .milliseconds)) {
111- // receiver will be stopped async
112- assert(receiver.isStopped)
113- assert(receiver.onStopCalled)
114- }
115- eventually(timeout(1 .second), interval(10 .milliseconds)) {
116- // receiver will be started async
117- assert(receiver.onStartCalled)
118- assert(executor.isReceiverStarted)
109+ executor.callsRecorder.reset()
110+ receiver.callsRecorder.reset()
111+ receiver.restart(" restarting" , null , 100 )
112+ eventually(timeout(10 .seconds), interval(10 .milliseconds)) {
113+ // below verification ensures for now receiver is already restarted
119114 assert(receiver.isStarted)
120115 assert(! receiver.isStopped)
121116 assert(receiver.receiving)
117+
118+ // both receiver supervisor and receiver should be stopped first, and started
119+ assert(executor.callsRecorder.calls === Seq (" onReceiverStop" , " onReceiverStart" ))
120+ assert(receiver.callsRecorder.calls === Seq (" onStop" , " onStart" ))
121+
122+ // check whether the delay between stop and start is respected
123+ assert(executor.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100 )
124+ assert(receiver.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100 )
122125 }
123126
124127 // Verify that stopping actually stops the thread
@@ -290,6 +293,9 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
290293 val arrayBuffers = new ArrayBuffer [ArrayBuffer [_]]
291294 val errors = new ArrayBuffer [Throwable ]
292295
296+ // tracks calls of "onReceiverStart", "onReceiverStop"
297+ val callsRecorder = new MethodsCallRecorder ()
298+
293299 /** Check if all data structures are clean */
294300 def isAllEmpty : Boolean = {
295301 singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
@@ -325,7 +331,15 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
325331 errors += throwable
326332 }
327333
328- override protected def onReceiverStart (): Boolean = true
334+ override protected def onReceiverStart (): Boolean = {
335+ callsRecorder.record()
336+ true
337+ }
338+
339+ override protected def onReceiverStop (message : String , error : Option [Throwable ]): Unit = {
340+ callsRecorder.record()
341+ super .onReceiverStop(message, error)
342+ }
329343
330344 override def createBlockGenerator (
331345 blockGeneratorListener : BlockGeneratorListener ): BlockGenerator = {
@@ -363,36 +377,55 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
363377class FakeReceiver (sendData : Boolean = false ) extends Receiver [Int ](StorageLevel .MEMORY_ONLY ) {
364378 @ volatile var otherThread : Thread = null
365379 @ volatile var receiving = false
366- @ volatile var onStartCalled = false
367- @ volatile var onStopCalled = false
380+
381+ // tracks calls of "onStart", "onStop"
382+ val callsRecorder = new MethodsCallRecorder ()
368383
369384 def onStart () {
370385 otherThread = new Thread () {
371386 override def run () {
372387 receiving = true
373- var count = 0
374- while (! isStopped()) {
375- if (sendData) {
376- store(count)
377- count += 1
388+ try {
389+ var count = 0
390+ while (! isStopped()) {
391+ if (sendData) {
392+ store(count)
393+ count += 1
394+ }
395+ Thread .sleep(10 )
378396 }
379- Thread .sleep(10 )
397+ } finally {
398+ receiving = false
380399 }
381400 }
382401 }
383- onStartCalled = true
402+ callsRecorder.record()
384403 otherThread.start()
385404 }
386405
387406 def onStop () {
388- onStopCalled = true
407+ callsRecorder.record()
389408 otherThread.join()
390409 }
391-
392- def reset () {
393- receiving = false
394- onStartCalled = false
395- onStopCalled = false
396- }
397410}
398411
412+ class MethodsCallRecorder {
413+ // tracks calling methods as (timestamp, methodName)
414+ private val records = new ArrayBuffer [(Long , String )]
415+
416+ def record (): Unit = records.append((System .currentTimeMillis(), callerMethodName))
417+
418+ def reset (): Unit = records.clear()
419+
420+ def callsWithTimestamp : scala.collection.immutable.Seq [(Long , String )] = records.toList
421+
422+ def calls : scala.collection.immutable.Seq [String ] = records.map(_._2).toList
423+
424+ def timestamps : scala.collection.immutable.Seq [Long ] = records.map(_._1).toList
425+
426+ private def callerMethodName : String = {
427+ val stackTrace = new Throwable ().getStackTrace
428+ // it should return method name of two levels deeper
429+ stackTrace(2 ).getMethodName
430+ }
431+ }
0 commit comments