@@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps}
 import scala.util.Random
 
 import org.apache.hadoop.conf.Configuration
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{doThrow, reset, spy}
 import org.scalatest.{BeforeAndAfter, Matchers}
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.Logging
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
@@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite
     tracker2.stop()
   }
 
+  test("block allocation to batch should not lose blocks from received queue") {
+    val tracker1 = spy(createTracker())
+    tracker1.isWriteAheadLogEnabled should be (true)
+    tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+
+    // Add blocks
+    val blockInfos = generateBlockInfos()
+    blockInfos.map(tracker1.addBlock)
+    tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos
+
+    // Try to allocate the blocks to a batch and verify that it fails;
+    // the blocks should stay in the received queue when the WAL write fails.
+    doThrow(new RuntimeException("Not able to write BatchAllocationEvent"))
+      .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent]))
+    val errMsg = intercept[RuntimeException] {
+      tracker1.allocateBlocksToBatch(1)
+    }
+    assert(errMsg.getMessage === "Not able to write BatchAllocationEvent")
+    tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos
+    tracker1.getBlocksOfBatch(1) shouldEqual Map.empty
+    tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty
+
+    // Allocate the blocks to a batch and verify that all of them have been allocated
+    reset(tracker1)
+    tracker1.allocateBlocksToBatch(2)
+    tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+    tracker1.hasUnallocatedReceivedBlocks should be (false)
+    tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos)
+    tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos
+
+    tracker1.stop()
+
+    // Recover from the WAL to verify correctness
+    val tracker2 = createTracker(recoverFromWriteAheadLog = true)
+    tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+    tracker2.hasUnallocatedReceivedBlocks should be (false)
+    tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos)
+    tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos
+    tracker2.stop()
+  }
+
   test("recovery and cleanup with write ahead logs") {
     val manualClock = new ManualClock
     // Set the time increment level to twice the rotation interval so that every increment creates
@@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite
       recoverFromWriteAheadLog: Boolean = false,
       clock: Clock = new SystemClock): ReceivedBlockTracker = {
     val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None
-    val tracker = new ReceivedBlockTracker(
+    var tracker = new ReceivedBlockTracker(
       conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption)
     allReceivedBlockTrackers += tracker
     tracker
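
For readers unfamiliar with the pattern the new test relies on: a Mockito spy wraps a real object and delegates every call to it until a stub is installed, which is how the test forces `writeToLog` to throw for `BatchAllocationEvent` while the rest of the tracker behaves normally, and `reset` removes the stub so the second allocation can succeed. Below is a minimal, self-contained sketch of the same spy/doThrow/reset pattern; `LogWriter` and `SpyDoThrowSketch` are hypothetical names used only for illustration, not part of Spark.

```scala
import org.mockito.Matchers.any
import org.mockito.Mockito.{doThrow, reset, spy}

// Hypothetical stand-in for a WAL-backed writer; not part of Spark.
class LogWriter {
  def writeToLog(event: String): Unit = println(s"wrote: $event")
}

object SpyDoThrowSketch {
  def main(args: Array[String]): Unit = {
    // The spy delegates to the real LogWriter until a stub is installed.
    val writer = spy(new LogWriter)

    // Stub writeToLog to fail, as the test does for BatchAllocationEvent.
    doThrow(new RuntimeException("WAL write failed"))
      .when(writer).writeToLog(any(classOf[String]))

    try writer.writeToLog("BatchAllocationEvent")
    catch { case e: RuntimeException => println(s"caught: ${e.getMessage}") }

    // reset clears the stub, so calls delegate to the real method again.
    reset(writer)
    writer.writeToLog("BatchAllocationEvent") // prints "wrote: BatchAllocationEvent"
  }
}
```

Note the `doThrow(...).when(spy).method(...)` ordering, which the diff above also uses: the conventional `when(spy.method(...)).thenThrow(...)` form would invoke the real method once while the stub is being set up, which this ordering avoids.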