From ed100381fc7bb7e0fffcc9d48a8b1e5d5b9915d4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 10 Nov 2015 13:24:09 -0800 Subject: [PATCH 1/5] Add waiter to make sure a single element is added --- .../spark/streaming/util/WriteAheadLogSuite.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index e96f4c2a2934..d3facc27ef0f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -37,6 +37,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.AsyncAssertions.Waiter import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar @@ -379,24 +380,26 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( * In order to block the writes on the writer thread, we mock the write method, and block it * for some time with a promise. */ - private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = { + private def writeBlockingPromise(wal: WriteAheadLog): (Promise[Any], Waiter) = { // we would like to block the write so that we can queue requests val promise = Promise[Any]() + val waiter = new Waiter when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( new Answer[WriteAheadLogRecordHandle] { override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { + waiter.dismiss() Await.ready(promise.future, 4.seconds) walHandle } } ) - promise + (promise, waiter) } test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) // block the write so that we can batch some records - val promise = writeBlockingPromise(wal) + val (promise, waiter) = writeBlockingPromise(wal) val event1 = "hello" val event2 = "world" @@ -407,6 +410,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. promiseWriteEvent(batchedWal, event1, 3L) + waiter.await() // rest of the records will be batched while it takes 3 to get written promiseWriteEvent(batchedWal, event2, 5L) promiseWriteEvent(batchedWal, event3, 8L) From 6d5e50e1c972ea660f13a3e44e8524eeb7ef5a5d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 10 Nov 2015 14:04:41 -0800 Subject: [PATCH 2/5] address 1 --- .../streaming/util/WriteAheadLogSuite.scala | 90 ++++++++++--------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index d3facc27ef0f..f580a7b09726 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util.concurrent.{ExecutionException, ThreadPoolExecutor} -import java.util.concurrent.atomic.AtomicInteger +import java.util.{Iterator => JIterator} +import java.util.concurrent.ThreadPoolExecutor import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.util.{Failure, Success} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -37,13 +36,12 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.AsyncAssertions.Waiter import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ThreadUtils, ManualClock, Utils} -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( @@ -367,7 +365,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } // we make the write requests in separate threads so that we don't block the test thread - private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { val p = Promise[Unit]() p.completeWith(Future { val v = wal.write(event, time) @@ -376,30 +374,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - /** - * In order to block the writes on the writer thread, we mock the write method, and block it - * for some time with a promise. - */ - private def writeBlockingPromise(wal: WriteAheadLog): (Promise[Any], Waiter) = { - // we would like to block the write so that we can queue requests - val promise = Promise[Any]() - val waiter = new Waiter - when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( - new Answer[WriteAheadLogRecordHandle] { - override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { - waiter.dismiss() - Await.ready(promise.future, 4.seconds) - walHandle - } - } - ) - (promise, waiter) - } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // block the write so that we can batch some records - val (promise, waiter) = writeBlockingPromise(wal) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) val event1 = "hello" val event2 = "world" @@ -409,17 +386,19 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. - promiseWriteEvent(batchedWal, event1, 3L) - waiter.await() + writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + } // rest of the records will be batched while it takes 3 to get written - promiseWriteEvent(batchedWal, event2, 5L) - promiseWriteEvent(batchedWal, event3, 8L) - promiseWriteEvent(batchedWal, event4, 12L) - promiseWriteEvent(batchedWal, event5, 10L) + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + writeAsync(batchedWal, event4, 12L) + writeAsync(batchedWal, event5, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) } - promise.success(true) + blockingWal.allowWrite() val buffer1 = wrapArrayArrayByte(Array(event1)) val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) @@ -441,10 +420,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } test("BatchedWriteAheadLog - fail everything in queue during shutdown") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - - // block the write so that we can batch some records - writeBlockingPromise(wal) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) val event1 = ("hello", 3L) val event2 = ("world", 5L) @@ -455,7 +432,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. val writePromises = Seq(event1, event2, event3, event4, event5).map { event => - promiseWriteEvent(batchedWal, event._1, event._2) + writeAsync(batchedWal, event._1, event._2) } eventually(timeout(1 second)) { @@ -645,4 +622,35 @@ object WriteAheadLogSuite { def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. + */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = ??? + override def readAll(): JIterator[ByteBuffer] = ??? + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = ??? + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } From 66f04d8298738fa78958dfc426173bd55010761b Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 10 Nov 2015 14:32:24 -0800 Subject: [PATCH 3/5] fix scalastyle --- .../apache/spark/streaming/util/WriteAheadLogSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index f580a7b09726..722e281e585c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -642,9 +642,11 @@ object WriteAheadLogSuite { isWriteCalled = false handle } - override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = ??? - override def readAll(): JIterator[ByteBuffer] = ??? - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = ??? + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } override def close(): Unit = wal.close() def allowWrite(): Unit = { From c3c177d36caa711f0da646c6057d9ccf847165ff Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 10 Nov 2015 17:51:01 -0800 Subject: [PATCH 4/5] Address comment --- .../streaming/util/BatchedWriteAheadLog.scala | 3 +++ .../spark/streaming/util/WriteAheadLogSuite.scala | 14 +++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 9727ed2ba144..6e6ed8d81972 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -182,6 +182,9 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp buffer.clear() } } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() } /** Static methods for aggregating and de-aggregating records. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 722e281e585c..646c327ce40d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -36,7 +36,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.scheduler._ @@ -314,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( allowBatching = true, closeFileAfterWrite = false, - "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually { + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { import BatchedWriteAheadLog._ import WriteAheadLogSuite._ @@ -325,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( private var walBatchingExecutionContext: ExecutionContextExecutorService = _ private val sparkConf = new SparkConf() + private val queueLength = PrivateMethod[Int]('getQueueLength) + override def beforeEach(): Unit = { wal = mock[WriteAheadLog] walHandle = mock[WriteAheadLogRecordHandle] @@ -390,13 +396,14 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( eventually(timeout(1 second)) { assert(blockingWal.isBlocked) } - // rest of the records will be batched while it takes 3 to get written + // rest of the records will be batched while it takes time for 3 to get written writeAsync(batchedWal, event2, 5L) writeAsync(batchedWal, event3, 8L) writeAsync(batchedWal, event4, 12L) writeAsync(batchedWal, event5, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) } blockingWal.allowWrite() @@ -404,6 +411,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) verify(wal, times(1)).write(meq(buffer1), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. From 8f8d4103388464e075941bea9997b93946361040 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 10 Nov 2015 18:42:58 -0800 Subject: [PATCH 5/5] update --- .../streaming/util/WriteAheadLogSuite.scala | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 646c327ce40d..9e13f25c2efe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -395,6 +395,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( writeAsync(batchedWal, event1, 3L) eventually(timeout(1 second)) { assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) } // rest of the records will be batched while it takes time for 3 to get written writeAsync(batchedWal, event2, 5L) @@ -431,22 +432,29 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( val blockingWal = new BlockingWriteAheadLog(wal, walHandle) val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) - val event1 = ("hello", 3L) - val event2 = ("world", 5L) - val event3 = ("this", 8L) - val event4 = ("is", 9L) - val event5 = ("doge", 10L) + val event1 = "hello" + val event2 = "world" + val event3 = "this" // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. - val writePromises = Seq(event1, event2, event3, event4, event5).map { event => - writeAsync(batchedWal, event._1, event._2) + val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) eventually(timeout(1 second)) { - assert(walBatchingThreadPool.getActiveCount === 5) + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written } + val writePromises = Seq(promise1, promise2, promise3) + batchedWal.close() eventually(timeout(1 second)) { assert(writePromises.forall(_.isCompleted))