@@ -18,15 +18,14 @@ package org.apache.spark.streaming.util
1818
1919import java .io ._
2020import java .nio .ByteBuffer
21- import java .util .concurrent .{ ExecutionException , ThreadPoolExecutor }
22- import java .util .concurrent .atomic . AtomicInteger
21+ import java .util .{ Iterator => JIterator }
22+ import java .util .concurrent .ThreadPoolExecutor
2323
2424import scala .collection .JavaConverters ._
2525import scala .collection .mutable .ArrayBuffer
2626import scala .concurrent ._
2727import scala .concurrent .duration ._
2828import scala .language .{implicitConversions , postfixOps }
29- import scala .util .{Failure , Success }
3029
3130import org .apache .hadoop .conf .Configuration
3231import org .apache .hadoop .fs .Path
@@ -37,12 +36,12 @@ import org.mockito.invocation.InvocationOnMock
3736import org .mockito .stubbing .Answer
3837import org .scalatest .concurrent .Eventually
3938import org .scalatest .concurrent .Eventually ._
40- import org .scalatest .{BeforeAndAfterEach , BeforeAndAfter }
39+ import org .scalatest .{PrivateMethodTester , BeforeAndAfterEach , BeforeAndAfter }
4140import org .scalatest .mock .MockitoSugar
4241
4342import org .apache .spark .streaming .scheduler ._
4443import org .apache .spark .util .{ThreadUtils , ManualClock , Utils }
45- import org .apache .spark .{SparkException , SparkConf , SparkFunSuite }
44+ import org .apache .spark .{SparkConf , SparkFunSuite }
4645
4746/** Common tests for WriteAheadLogs that we would like to test with different configurations. */
4847abstract class CommonWriteAheadLogTests (
@@ -315,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite
315314class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests (
316315 allowBatching = true ,
317316 closeFileAfterWrite = false ,
318- " BatchedWriteAheadLog" ) with MockitoSugar with BeforeAndAfterEach with Eventually {
317+ " BatchedWriteAheadLog" )
318+ with MockitoSugar
319+ with BeforeAndAfterEach
320+ with Eventually
321+ with PrivateMethodTester {
319322
320323 import BatchedWriteAheadLog ._
321324 import WriteAheadLogSuite ._
@@ -326,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
326329 private var walBatchingExecutionContext : ExecutionContextExecutorService = _
327330 private val sparkConf = new SparkConf ()
328331
332+ private val queueLength = PrivateMethod [Int ](' getQueueLength )
333+
329334 override def beforeEach (): Unit = {
330335 wal = mock[WriteAheadLog ]
331336 walHandle = mock[WriteAheadLogRecordHandle ]
@@ -366,7 +371,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
366371 }
367372
368373 // we make the write requests in separate threads so that we don't block the test thread
369- private def promiseWriteEvent (wal : WriteAheadLog , event : String , time : Long ): Promise [Unit ] = {
374+ private def writeAsync (wal : WriteAheadLog , event : String , time : Long ): Promise [Unit ] = {
370375 val p = Promise [Unit ]()
371376 p.completeWith(Future {
372377 val v = wal.write(event, time)
@@ -375,28 +380,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
375380 p
376381 }
377382
378- /**
379- * In order to block the writes on the writer thread, we mock the write method, and block it
380- * for some time with a promise.
381- */
382- private def writeBlockingPromise (wal : WriteAheadLog ): Promise [Any ] = {
383- // we would like to block the write so that we can queue requests
384- val promise = Promise [Any ]()
385- when(wal.write(any[ByteBuffer ], any[Long ])).thenAnswer(
386- new Answer [WriteAheadLogRecordHandle ] {
387- override def answer (invocation : InvocationOnMock ): WriteAheadLogRecordHandle = {
388- Await .ready(promise.future, 4 .seconds)
389- walHandle
390- }
391- }
392- )
393- promise
394- }
395-
396383 test(" BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry" ) {
397- val batchedWal = new BatchedWriteAheadLog (wal, sparkConf)
398- // block the write so that we can batch some records
399- val promise = writeBlockingPromise(wal)
384+ val blockingWal = new BlockingWriteAheadLog (wal, walHandle)
385+ val batchedWal = new BatchedWriteAheadLog (blockingWal, sparkConf)
400386
401387 val event1 = " hello"
402388 val event2 = " world"
@@ -406,21 +392,27 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
406392
407393 // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
408394 // moment. Then the promise blocks the writing of 3. The rest get queued.
409- promiseWriteEvent(batchedWal, event1, 3L )
410- // rest of the records will be batched while it takes 3 to get written
411- promiseWriteEvent(batchedWal, event2, 5L )
412- promiseWriteEvent(batchedWal, event3, 8L )
413- promiseWriteEvent(batchedWal, event4, 12L )
414- promiseWriteEvent(batchedWal, event5, 10L )
395+ writeAsync(batchedWal, event1, 3L )
396+ eventually(timeout(1 second)) {
397+ assert(blockingWal.isBlocked)
398+ assert(batchedWal.invokePrivate(queueLength()) === 0 )
399+ }
400+ // rest of the records will be batched while it takes time for 3 to get written
401+ writeAsync(batchedWal, event2, 5L )
402+ writeAsync(batchedWal, event3, 8L )
403+ writeAsync(batchedWal, event4, 12L )
404+ writeAsync(batchedWal, event5, 10L )
415405 eventually(timeout(1 second)) {
416406 assert(walBatchingThreadPool.getActiveCount === 5 )
407+ assert(batchedWal.invokePrivate(queueLength()) === 4 )
417408 }
418- promise.success( true )
409+ blockingWal.allowWrite( )
419410
420411 val buffer1 = wrapArrayArrayByte(Array (event1))
421412 val buffer2 = wrapArrayArrayByte(Array (event2, event3, event4, event5))
422413
423414 eventually(timeout(1 second)) {
415+ assert(batchedWal.invokePrivate(queueLength()) === 0 )
424416 verify(wal, times(1 )).write(meq(buffer1), meq(3L ))
425417 // the file name should be the timestamp of the last record, as events should be naturally
426418 // in order of timestamp, and we need the last element.
@@ -437,27 +429,32 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
437429 }
438430
439431 test(" BatchedWriteAheadLog - fail everything in queue during shutdown" ) {
440- val batchedWal = new BatchedWriteAheadLog (wal, sparkConf)
432+ val blockingWal = new BlockingWriteAheadLog (wal, walHandle)
433+ val batchedWal = new BatchedWriteAheadLog (blockingWal, sparkConf)
441434
442- // block the write so that we can batch some records
443- writeBlockingPromise(wal)
444-
445- val event1 = (" hello" , 3L )
446- val event2 = (" world" , 5L )
447- val event3 = (" this" , 8L )
448- val event4 = (" is" , 9L )
449- val event5 = (" doge" , 10L )
435+ val event1 = " hello"
436+ val event2 = " world"
437+ val event3 = " this"
450438
451439 // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
452440 // moment. Then the promise blocks the writing of 3. The rest get queued.
453- val writePromises = Seq (event1, event2, event3, event4, event5).map { event =>
454- promiseWriteEvent(batchedWal, event._1, event._2)
441+ val promise1 = writeAsync(batchedWal, event1, 3L )
442+ eventually(timeout(1 second)) {
443+ assert(blockingWal.isBlocked)
444+ assert(batchedWal.invokePrivate(queueLength()) === 0 )
455445 }
446+ // rest of the records will be batched while it takes time for 3 to get written
447+ val promise2 = writeAsync(batchedWal, event2, 5L )
448+ val promise3 = writeAsync(batchedWal, event3, 8L )
456449
457450 eventually(timeout(1 second)) {
458- assert(walBatchingThreadPool.getActiveCount === 5 )
451+ assert(walBatchingThreadPool.getActiveCount === 3 )
452+ assert(blockingWal.isBlocked)
453+ assert(batchedWal.invokePrivate(queueLength()) === 2 ) // event1 is being written
459454 }
460455
456+ val writePromises = Seq (promise1, promise2, promise3)
457+
461458 batchedWal.close()
462459 eventually(timeout(1 second)) {
463460 assert(writePromises.forall(_.isCompleted))
@@ -641,4 +638,37 @@ object WriteAheadLogSuite {
641638 def wrapArrayArrayByte [T ](records : Array [T ]): ByteBuffer = {
642639 ByteBuffer .wrap(Utils .serialize[Array [Array [Byte ]]](records.map(Utils .serialize[T ])))
643640 }
641+
642+ /**
643+ * A wrapper WriteAheadLog that blocks the write function to allow batching with the
644+ * BatchedWriteAheadLog.
645+ */
646+ class BlockingWriteAheadLog (
647+ wal : WriteAheadLog ,
648+ handle : WriteAheadLogRecordHandle ) extends WriteAheadLog {
649+ @ volatile private var isWriteCalled : Boolean = false
650+ @ volatile private var blockWrite : Boolean = true
651+
652+ override def write (record : ByteBuffer , time : Long ): WriteAheadLogRecordHandle = {
653+ isWriteCalled = true
654+ eventually(Eventually .timeout(2 second)) {
655+ assert(! blockWrite)
656+ }
657+ wal.write(record, time)
658+ isWriteCalled = false
659+ handle
660+ }
661+ override def read (segment : WriteAheadLogRecordHandle ): ByteBuffer = wal.read(segment)
662+ override def readAll (): JIterator [ByteBuffer ] = wal.readAll()
663+ override def clean (threshTime : Long , waitForCompletion : Boolean ): Unit = {
664+ wal.clean(threshTime, waitForCompletion)
665+ }
666+ override def close (): Unit = wal.close()
667+
668+ def allowWrite (): Unit = {
669+ blockWrite = false
670+ }
671+
672+ def isBlocked : Boolean = isWriteCalled
673+ }
644674}
0 commit comments