Skip to content

Commit 27029bc

Browse files
brkyvztdas
authored andcommitted
[SPARK-11639][STREAMING][FLAKY-TEST] Implement BlockingWriteAheadLog for testing the BatchedWriteAheadLog
Several elements could be drained if the main thread is not fast enough. zsxwing warned me about a similar problem, but missed it here :( Submitting the fix using a waiter. cc tdas Author: Burak Yavuz <[email protected]> Closes #9605 from brkyvz/fix-flaky-test.
1 parent 529a1d3 commit 27029bc

File tree

2 files changed

+80
-47
lines changed

2 files changed

+80
-47
lines changed

streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
182182
buffer.clear()
183183
}
184184
}
185+
186+
/** Method for querying the queue length. Should only be used in tests. */
187+
private def getQueueLength(): Int = walWriteQueue.size()
185188
}
186189

187190
/** Static methods for aggregating and de-aggregating records. */

streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@ package org.apache.spark.streaming.util
1818

1919
import java.io._
2020
import java.nio.ByteBuffer
21-
import java.util.concurrent.{ExecutionException, ThreadPoolExecutor}
22-
import java.util.concurrent.atomic.AtomicInteger
21+
import java.util.{Iterator => JIterator}
22+
import java.util.concurrent.ThreadPoolExecutor
2323

2424
import scala.collection.JavaConverters._
2525
import scala.collection.mutable.ArrayBuffer
2626
import scala.concurrent._
2727
import scala.concurrent.duration._
2828
import scala.language.{implicitConversions, postfixOps}
29-
import scala.util.{Failure, Success}
3029

3130
import org.apache.hadoop.conf.Configuration
3231
import org.apache.hadoop.fs.Path
@@ -37,12 +36,12 @@ import org.mockito.invocation.InvocationOnMock
3736
import org.mockito.stubbing.Answer
3837
import org.scalatest.concurrent.Eventually
3938
import org.scalatest.concurrent.Eventually._
40-
import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter}
39+
import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter}
4140
import org.scalatest.mock.MockitoSugar
4241

4342
import org.apache.spark.streaming.scheduler._
4443
import org.apache.spark.util.{ThreadUtils, ManualClock, Utils}
45-
import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
44+
import org.apache.spark.{SparkConf, SparkFunSuite}
4645

4746
/** Common tests for WriteAheadLogs that we would like to test with different configurations. */
4847
abstract class CommonWriteAheadLogTests(
@@ -315,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite
315314
class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
316315
allowBatching = true,
317316
closeFileAfterWrite = false,
318-
"BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually {
317+
"BatchedWriteAheadLog")
318+
with MockitoSugar
319+
with BeforeAndAfterEach
320+
with Eventually
321+
with PrivateMethodTester {
319322

320323
import BatchedWriteAheadLog._
321324
import WriteAheadLogSuite._
@@ -326,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
326329
private var walBatchingExecutionContext: ExecutionContextExecutorService = _
327330
private val sparkConf = new SparkConf()
328331

332+
private val queueLength = PrivateMethod[Int]('getQueueLength)
333+
329334
override def beforeEach(): Unit = {
330335
wal = mock[WriteAheadLog]
331336
walHandle = mock[WriteAheadLogRecordHandle]
@@ -366,7 +371,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
366371
}
367372

368373
// we make the write requests in separate threads so that we don't block the test thread
369-
private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
374+
private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
370375
val p = Promise[Unit]()
371376
p.completeWith(Future {
372377
val v = wal.write(event, time)
@@ -375,28 +380,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
375380
p
376381
}
377382

378-
/**
379-
* In order to block the writes on the writer thread, we mock the write method, and block it
380-
* for some time with a promise.
381-
*/
382-
private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = {
383-
// we would like to block the write so that we can queue requests
384-
val promise = Promise[Any]()
385-
when(wal.write(any[ByteBuffer], any[Long])).thenAnswer(
386-
new Answer[WriteAheadLogRecordHandle] {
387-
override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = {
388-
Await.ready(promise.future, 4.seconds)
389-
walHandle
390-
}
391-
}
392-
)
393-
promise
394-
}
395-
396383
test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") {
397-
val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
398-
// block the write so that we can batch some records
399-
val promise = writeBlockingPromise(wal)
384+
val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
385+
val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
400386

401387
val event1 = "hello"
402388
val event2 = "world"
@@ -406,21 +392,27 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
406392

407393
// The queue.take() immediately takes the 3, and there is nothing left in the queue at that
408394
// moment. Then the promise blocks the writing of 3. The rest get queued.
409-
promiseWriteEvent(batchedWal, event1, 3L)
410-
// rest of the records will be batched while it takes 3 to get written
411-
promiseWriteEvent(batchedWal, event2, 5L)
412-
promiseWriteEvent(batchedWal, event3, 8L)
413-
promiseWriteEvent(batchedWal, event4, 12L)
414-
promiseWriteEvent(batchedWal, event5, 10L)
395+
writeAsync(batchedWal, event1, 3L)
396+
eventually(timeout(1 second)) {
397+
assert(blockingWal.isBlocked)
398+
assert(batchedWal.invokePrivate(queueLength()) === 0)
399+
}
400+
// rest of the records will be batched while it takes time for 3 to get written
401+
writeAsync(batchedWal, event2, 5L)
402+
writeAsync(batchedWal, event3, 8L)
403+
writeAsync(batchedWal, event4, 12L)
404+
writeAsync(batchedWal, event5, 10L)
415405
eventually(timeout(1 second)) {
416406
assert(walBatchingThreadPool.getActiveCount === 5)
407+
assert(batchedWal.invokePrivate(queueLength()) === 4)
417408
}
418-
promise.success(true)
409+
blockingWal.allowWrite()
419410

420411
val buffer1 = wrapArrayArrayByte(Array(event1))
421412
val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5))
422413

423414
eventually(timeout(1 second)) {
415+
assert(batchedWal.invokePrivate(queueLength()) === 0)
424416
verify(wal, times(1)).write(meq(buffer1), meq(3L))
425417
// the file name should be the timestamp of the last record, as events should be naturally
426418
// in order of timestamp, and we need the last element.
@@ -437,27 +429,32 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
437429
}
438430

439431
test("BatchedWriteAheadLog - fail everything in queue during shutdown") {
440-
val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
432+
val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
433+
val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
441434

442-
// block the write so that we can batch some records
443-
writeBlockingPromise(wal)
444-
445-
val event1 = ("hello", 3L)
446-
val event2 = ("world", 5L)
447-
val event3 = ("this", 8L)
448-
val event4 = ("is", 9L)
449-
val event5 = ("doge", 10L)
435+
val event1 = "hello"
436+
val event2 = "world"
437+
val event3 = "this"
450438

451439
// The queue.take() immediately takes the 3, and there is nothing left in the queue at that
452440
// moment. Then the promise blocks the writing of 3. The rest get queued.
453-
val writePromises = Seq(event1, event2, event3, event4, event5).map { event =>
454-
promiseWriteEvent(batchedWal, event._1, event._2)
441+
val promise1 = writeAsync(batchedWal, event1, 3L)
442+
eventually(timeout(1 second)) {
443+
assert(blockingWal.isBlocked)
444+
assert(batchedWal.invokePrivate(queueLength()) === 0)
455445
}
446+
// rest of the records will be batched while it takes time for 3 to get written
447+
val promise2 = writeAsync(batchedWal, event2, 5L)
448+
val promise3 = writeAsync(batchedWal, event3, 8L)
456449

457450
eventually(timeout(1 second)) {
458-
assert(walBatchingThreadPool.getActiveCount === 5)
451+
assert(walBatchingThreadPool.getActiveCount === 3)
452+
assert(blockingWal.isBlocked)
453+
assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written
459454
}
460455

456+
val writePromises = Seq(promise1, promise2, promise3)
457+
461458
batchedWal.close()
462459
eventually(timeout(1 second)) {
463460
assert(writePromises.forall(_.isCompleted))
@@ -641,4 +638,37 @@ object WriteAheadLogSuite {
641638
def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = {
642639
ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T])))
643640
}
641+
642+
/**
643+
* A wrapper WriteAheadLog that blocks the write function to allow batching with the
644+
* BatchedWriteAheadLog.
645+
*/
646+
class BlockingWriteAheadLog(
647+
wal: WriteAheadLog,
648+
handle: WriteAheadLogRecordHandle) extends WriteAheadLog {
649+
@volatile private var isWriteCalled: Boolean = false
650+
@volatile private var blockWrite: Boolean = true
651+
652+
override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = {
653+
isWriteCalled = true
654+
eventually(Eventually.timeout(2 second)) {
655+
assert(!blockWrite)
656+
}
657+
wal.write(record, time)
658+
isWriteCalled = false
659+
handle
660+
}
661+
override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment)
662+
override def readAll(): JIterator[ByteBuffer] = wal.readAll()
663+
override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = {
664+
wal.clean(threshTime, waitForCompletion)
665+
}
666+
override def close(): Unit = wal.close()
667+
668+
def allowWrite(): Unit = {
669+
blockWrite = false
670+
}
671+
672+
def isBlocked: Boolean = isWriteCalled
673+
}
644674
}

0 commit comments

Comments
 (0)