From 21bd1c37f4af6480adfc07130a15f70acdeda378 Mon Sep 17 00:00:00 2001 From: liyuanjian Date: Tue, 21 Aug 2018 13:24:07 +0800 Subject: [PATCH 1/6] [SPARK-25017][Core] Add test suite for BarrierCoordinator and ContextBarrierState --- .../org/apache/spark/BarrierCoordinator.scala | 6 +- .../scheduler/BarrierCoordinatorSuite.scala | 153 ++++++++++++++++++ 2 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 5e546c694e8d9..30b792f5b4c07 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -65,7 +65,7 @@ private[spark] class BarrierCoordinator( // Record all active stage attempts that make barrier() call(s), and the corresponding internal // state. - private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + private[spark] val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { super.onStart() @@ -90,14 +90,14 @@ private[spark] class BarrierCoordinator( * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall * collect `numTasks` requests to succeed. */ - private class ContextBarrierState( + private[spark] class ContextBarrierState( val barrierId: ContextBarrierId, val numTasks: Int) { // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or // reset when a barrier() call fails due to timeout. - private var barrierEpoch: Int = 0 + private[spark] var barrierEpoch: Int = 0 // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() // call. diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala new file mode 100644 index 0000000000000..cd298d4659456 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.TimeoutException + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark._ +import org.apache.spark.rpc.RpcTimeout + +class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext { + + /** + * Get the current barrierEpoch from barrierCoordinator.states by ContextBarrierId + */ + def getCurrentBarrierEpoch( + stageId: Int, stageAttemptId: Int, barrierCoordinator: BarrierCoordinator): Int = { + val barrierId = ContextBarrierId(stageId, stageAttemptId) + barrierCoordinator.states.get(barrierId).barrierEpoch + } + + test("normal test for single task") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val stageId = 0 + val stageAttemptNumber = 0 + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks = 1, stageId, stageAttemptNumber, taskAttemptId = 0, + barrierEpoch = 0), + timeout = new RpcTimeout(5 seconds, "rpcTimeOut")) + // sleep for waiting barrierEpoch value change + Thread.sleep(500) + assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + } + + test("normal test for multi tasks") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") + // sync request from 3 tasks + (0 until numTasks).foreach { taskId => + new Thread(s"task-$taskId-thread") { + setDaemon(true) + override def run(): Unit = { + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = taskId, + barrierEpoch = 0), + timeout = rpcTimeOut) + } + }.start() + } + // sleep for waiting barrierEpoch value change + Thread.sleep(500) + assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + } + + test("abnormal test for syncing with illegal barrierId") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") + intercept[SparkException]( + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, + barrierEpoch = -1), // illegal barrierId = -1 + timeout = rpcTimeOut)) + } + + test("abnormal test for syncing with old barrierId") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") + // sync request from 3 tasks + (0 until numTasks).foreach { taskId => + new Thread(s"task-$taskId-thread") { + setDaemon(true) + override def run(): Unit = { + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = taskId, + barrierEpoch = 0), + timeout = rpcTimeOut) + } + }.start() + } + // sleep for waiting barrierEpoch value change + Thread.sleep(500) + assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + intercept[SparkException]( + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, + barrierEpoch = 0), + timeout = rpcTimeOut)) + } + + test("abnormal test for timeout when rpcTimeOut < barrierTimeOut") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcTimeOut = new RpcTimeout(1 seconds, "rpcTimeOut") + intercept[TimeoutException]( + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, + barrierEpoch = 0), + timeout = rpcTimeOut)) + } + + test("abnormal test for timeout when rpcTimeOut > barrierTimeOut") { + sc = new SparkContext("local", "test") + val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv) + val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcTimeOut = new RpcTimeout(4 seconds, "rpcTimeOut") + intercept[SparkException]( + rpcEndpointRef.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, + barrierEpoch = 0), + timeout = rpcTimeOut)) + } +} From ecf12bdd78b4403806c053c2fc97f05cf37e67f9 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 6 Sep 2018 21:54:54 +0800 Subject: [PATCH 2/6] Address comments and add clean check for internal data. --- .../org/apache/spark/BarrierCoordinator.scala | 13 ++++-- .../scheduler/BarrierCoordinatorSuite.scala | 41 ++++++++++++------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 30b792f5b4c07..0be67cec00a4d 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -63,8 +63,12 @@ private[spark] class BarrierCoordinator( } } - // Record all active stage attempts that make barrier() call(s), and the corresponding internal - // state. + /** + * Record all active stage attempts that make barrier() call(s), and the corresponding internal + * state. + * + * Visible for testing. + */ private[spark] val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { @@ -84,7 +88,7 @@ private[spark] class BarrierCoordinator( /** * Provide the current state of a barrier() call. A state is created when a new stage attempt - * sends out a barrier() call, and recycled on stage completed. + * sends out a barrier() call, and recycled on stage completed. Visible for testing. * * @param barrierId Identifier of the barrier stage that make a barrier() call. * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall @@ -187,6 +191,9 @@ private[spark] class BarrierCoordinator( requesters.clear() cancelTimerTask() } + + // Check for clearing internal data, visible for test only. + private[spark] def cleanCheck(): Boolean = requesters.isEmpty && timerTask == null } // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala index cd298d4659456..81c384febbe8e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala @@ -22,18 +22,22 @@ import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps +import org.scalatest.concurrent.Eventually + import org.apache.spark._ import org.apache.spark.rpc.RpcTimeout -class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext { +class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with Eventually { /** - * Get the current barrierEpoch from barrierCoordinator.states by ContextBarrierId + * Get the current ContextBarrierState from barrierCoordinator.states by ContextBarrierId. */ - def getCurrentBarrierEpoch( - stageId: Int, stageAttemptId: Int, barrierCoordinator: BarrierCoordinator): Int = { + private def getBarrierState( + stageId: Int, + stageAttemptId: Int, + barrierCoordinator: BarrierCoordinator) = { val barrierId = ContextBarrierId(stageId, stageAttemptId) - barrierCoordinator.states.get(barrierId).barrierEpoch + barrierCoordinator.states.get(barrierId) } test("normal test for single task") { @@ -46,9 +50,12 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext { message = RequestToSync(numTasks = 1, stageId, stageAttemptNumber, taskAttemptId = 0, barrierEpoch = 0), timeout = new RpcTimeout(5 seconds, "rpcTimeOut")) - // sleep for waiting barrierEpoch value change - Thread.sleep(500) - assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + eventually(timeout(10.seconds)) { + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.barrierEpoch == 1) + assert(barrierState.cleanCheck()) + } } test("normal test for multi tasks") { @@ -71,9 +78,12 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext { } }.start() } - // sleep for waiting barrierEpoch value change - Thread.sleep(500) - assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + eventually(timeout(10.seconds)) { + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.barrierEpoch == 1) + assert(barrierState.cleanCheck()) + } } test("abnormal test for syncing with illegal barrierId") { @@ -111,9 +121,12 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext { } }.start() } - // sleep for waiting barrierEpoch value change - Thread.sleep(500) - assert(getCurrentBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator) == 1) + eventually(timeout(10.seconds)) { + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.barrierEpoch == 1) + assert(barrierState.cleanCheck()) + } intercept[SparkException]( rpcEndpointRef.askSync[Unit]( message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, From ec8466a8272b93e12fa651e65db65f12148576bb Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 18 Sep 2018 15:46:36 +0800 Subject: [PATCH 3/6] Address comment by kiszk --- .../main/scala/org/apache/spark/BarrierCoordinator.scala | 5 ++++- .../apache/spark/scheduler/BarrierCoordinatorSuite.scala | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 0be67cec00a4d..2eab2eb8ff004 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -101,7 +101,7 @@ private[spark] class BarrierCoordinator( // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or // reset when a barrier() call fails due to timeout. - private[spark] var barrierEpoch: Int = 0 + private var barrierEpoch: Int = 0 // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() // call. @@ -194,6 +194,9 @@ private[spark] class BarrierCoordinator( // Check for clearing internal data, visible for test only. private[spark] def cleanCheck(): Boolean = requesters.isEmpty && timerTask == null + + // Get currently barrier epoch, visible for test only. + private[spark] def getBarrierEpoch(): Int = barrierEpoch } // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala index 81c384febbe8e..8587a9eb8598e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala @@ -53,7 +53,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with eventually(timeout(10.seconds)) { // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.barrierEpoch == 1) + assert(barrierState.getBarrierEpoch() == 1) assert(barrierState.cleanCheck()) } } @@ -81,7 +81,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with eventually(timeout(10.seconds)) { // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.barrierEpoch == 1) + assert(barrierState.getBarrierEpoch() == 1) assert(barrierState.cleanCheck()) } } @@ -124,7 +124,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with eventually(timeout(10.seconds)) { // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.barrierEpoch == 1) + assert(barrierState.getBarrierEpoch() == 1) assert(barrierState.cleanCheck()) } intercept[SparkException]( From 8cd78a95a0e0649fed81fe6217790943855b7417 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 27 Sep 2018 23:59:19 +0800 Subject: [PATCH 4/6] Get rid of RPC framework in UT --- .../org/apache/spark/BarrierCoordinator.scala | 2 +- .../scheduler/BarrierCoordinatorSuite.scala | 181 ++++++++++-------- 2 files changed, 98 insertions(+), 85 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 2eab2eb8ff004..75af0b550a5cb 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -145,7 +145,7 @@ private[spark] class BarrierCoordinator( logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") if (epoch != barrierEpoch) { requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + - s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + s"barrier epoch $epoch has already finished. Maybe task $taskId is not " + "properly killed.")) } else { // If this is the first sync message received for a barrier() call, start timer to ensure diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala index 8587a9eb8598e..a4fa5e0339351 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.scheduler -import java.util.concurrent.TimeoutException - import scala.concurrent.duration._ import scala.language.postfixOps +import org.mockito.ArgumentMatcher +import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.scalatest.concurrent.Eventually import org.apache.spark._ -import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.rpc.{RpcAddress, RpcCallContext} class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -40,127 +41,139 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with barrierCoordinator.states.get(barrierId) } + private def mockRpcCallContext() = { + val rpcAddress = mock(classOf[RpcAddress]) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.senderAddress).thenReturn(rpcAddress) + rpcCallContext + } + test("normal test for single task") { sc = new SparkContext("local", "test") val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) val stageId = 0 val stageAttemptNumber = 0 - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks = 1, stageId, stageAttemptNumber, taskAttemptId = 0, - barrierEpoch = 0), - timeout = new RpcTimeout(5 seconds, "rpcTimeOut")) - eventually(timeout(10.seconds)) { - // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) - } + barrierCoordinator.receiveAndReply(mockRpcCallContext())( + RequestToSync( + numTasks = 1, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.getBarrierEpoch() == 1) + assert(barrierState.cleanCheck()) } test("normal test for multi tasks") { sc = new SparkContext("local", "test") val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 - val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") - // sync request from 3 tasks + // request from 3 tasks (0 until numTasks).foreach { taskId => - new Thread(s"task-$taskId-thread") { - setDaemon(true) - override def run(): Unit = { - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = taskId, - barrierEpoch = 0), - timeout = rpcTimeOut) - } - }.start() - } - eventually(timeout(10.seconds)) { - // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) + barrierCoordinator.receiveAndReply(mockRpcCallContext())( + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = taskId, + barrierEpoch = 0)) } + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.getBarrierEpoch() == 1) + assert(barrierState.cleanCheck()) } test("abnormal test for syncing with illegal barrierId") { sc = new SparkContext("local", "test") val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 - val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") - intercept[SparkException]( - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, - barrierEpoch = -1), // illegal barrierId = -1 - timeout = rpcTimeOut)) + val rpcCallContext = mockRpcCallContext() + barrierCoordinator.receiveAndReply(rpcCallContext)( + // illegal barrierId = -1 + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = -1)) + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The request to sync of Stage 0 (Attempt 0) with barrier epoch -1 has already" + + " finished. Maybe task 0 is not properly killed." + } + })) } test("abnormal test for syncing with old barrierId") { sc = new SparkContext("local", "test") val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 - val rpcTimeOut = new RpcTimeout(5 seconds, "rpcTimeOut") - // sync request from 3 tasks + val rpcCallContext = mockRpcCallContext() + // request from 3 tasks (0 until numTasks).foreach { taskId => - new Thread(s"task-$taskId-thread") { - setDaemon(true) - override def run(): Unit = { - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = taskId, - barrierEpoch = 0), - timeout = rpcTimeOut) - } - }.start() - } - eventually(timeout(10.seconds)) { - // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) + barrierCoordinator.receiveAndReply(mockRpcCallContext())( + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = taskId, + barrierEpoch = 0)) } - intercept[SparkException]( - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, - barrierEpoch = 0), - timeout = rpcTimeOut)) - } - - test("abnormal test for timeout when rpcTimeOut < barrierTimeOut") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) - val numTasks = 3 - val stageId = 0 - val stageAttemptNumber = 0 - val rpcTimeOut = new RpcTimeout(1 seconds, "rpcTimeOut") - intercept[TimeoutException]( - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, - barrierEpoch = 0), - timeout = rpcTimeOut)) + // Ensure barrierEpoch value have been changed. + val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) + assert(barrierState.getBarrierEpoch() == 1) + assert(barrierState.cleanCheck()) + barrierCoordinator.receiveAndReply(rpcCallContext)( + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The request to sync of Stage 0 (Attempt 0) with barrier epoch 0 has already" + + " finished. Maybe task 0 is not properly killed." + }})) } test("abnormal test for timeout when rpcTimeOut > barrierTimeOut") { sc = new SparkContext("local", "test") val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv) - val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 - val rpcTimeOut = new RpcTimeout(4 seconds, "rpcTimeOut") - intercept[SparkException]( - rpcEndpointRef.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0, - barrierEpoch = 0), - timeout = rpcTimeOut)) + val rpcCallContext = mockRpcCallContext() + barrierCoordinator.receiveAndReply(rpcCallContext)( + // illegal barrierId = -1 + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + eventually(timeout(5.seconds)) { + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The coordinator didn't get all barrier sync requests for barrier epoch" + + " 0 from Stage 0 (Attempt 0) within 2 second(s)." + } + })) + } } } From fd4d150bd638d8bfb91a487bd948bfaf3b3f0d56 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 28 Sep 2018 00:11:23 +0800 Subject: [PATCH 5/6] address comments --- .../main/scala/org/apache/spark/BarrierCoordinator.scala | 4 ++-- .../apache/spark/scheduler/BarrierCoordinatorSuite.scala | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 75af0b550a5cb..c0d9546710ff3 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -192,8 +192,8 @@ private[spark] class BarrierCoordinator( cancelTimerTask() } - // Check for clearing internal data, visible for test only. - private[spark] def cleanCheck(): Boolean = requesters.isEmpty && timerTask == null + // Check for internal state clear, visible for test only. + private[spark] def isInternalStateClear(): Boolean = requesters.isEmpty && timerTask == null // Get currently barrier epoch, visible for test only. private[spark] def getBarrierEpoch(): Int = barrierEpoch diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala index a4fa5e0339351..0c3119f815672 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala @@ -63,7 +63,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) + assert(barrierState.isInternalStateClear()) } test("normal test for multi tasks") { @@ -85,7 +85,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) + assert(barrierState.isInternalStateClear()) } test("abnormal test for syncing with illegal barrierId") { @@ -133,7 +133,7 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with // Ensure barrierEpoch value have been changed. val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.cleanCheck()) + assert(barrierState.isInternalStateClear()) barrierCoordinator.receiveAndReply(rpcCallContext)( RequestToSync( numTasks, From aea2fa0b7c3dbda1ff7b652fcb9e7232013840d7 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 13 Oct 2018 18:00:42 +0800 Subject: [PATCH 6/6] Don't launch SparkContext --- .../org/apache/spark/BarrierCoordinator.scala | 14 +-- ...e.scala => ContextBarrierStateSuite.scala} | 100 +++++++++--------- 2 files changed, 53 insertions(+), 61 deletions(-) rename core/src/test/scala/org/apache/spark/scheduler/{BarrierCoordinatorSuite.scala => ContextBarrierStateSuite.scala} (64%) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index c0d9546710ff3..2e984db5e2f07 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -32,7 +32,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListener * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is * from. */ -private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { +private[spark] case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" } @@ -63,13 +63,9 @@ private[spark] class BarrierCoordinator( } } - /** - * Record all active stage attempts that make barrier() call(s), and the corresponding internal - * state. - * - * Visible for testing. - */ - private[spark] val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state. + private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { super.onStart() @@ -225,7 +221,7 @@ private[spark] class BarrierCoordinator( } } -private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private sealed trait BarrierCoordinatorMessage extends Serializable /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala similarity index 64% rename from core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala index 0c3119f815672..af912c84c0dfc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala @@ -26,20 +26,9 @@ import org.mockito.Mockito._ import org.scalatest.concurrent.Eventually import org.apache.spark._ -import org.apache.spark.rpc.{RpcAddress, RpcCallContext} +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} -class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with Eventually { - - /** - * Get the current ContextBarrierState from barrierCoordinator.states by ContextBarrierId. - */ - private def getBarrierState( - stageId: Int, - stageAttemptId: Int, - barrierCoordinator: BarrierCoordinator) = { - val barrierId = ContextBarrierId(stageId, stageAttemptId) - barrierCoordinator.states.get(barrierId) - } +class ContextBarrierStateSuite extends SparkFunSuite with LocalSparkContext with Eventually { private def mockRpcCallContext() = { val rpcAddress = mock(classOf[RpcAddress]) @@ -49,11 +38,14 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with } test("normal test for single task") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) val stageId = 0 val stageAttemptNumber = 0 - barrierCoordinator.receiveAndReply(mockRpcCallContext())( + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks = 1) + state.handleRequest( + mockRpcCallContext(), RequestToSync( numTasks = 1, stageId, @@ -61,42 +53,43 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with taskAttemptId = 0, barrierEpoch = 0)) // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.isInternalStateClear()) + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) } test("normal test for multi tasks") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) // request from 3 tasks (0 until numTasks).foreach { taskId => - barrierCoordinator.receiveAndReply(mockRpcCallContext())( - RequestToSync( - numTasks, - stageId, - stageAttemptNumber, - taskAttemptId = taskId, - barrierEpoch = 0)) + state.handleRequest(mockRpcCallContext(), RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = taskId, + barrierEpoch = 0)) } // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.isInternalStateClear()) + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) } test("abnormal test for syncing with illegal barrierId") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 val rpcCallContext = mockRpcCallContext() - barrierCoordinator.receiveAndReply(rpcCallContext)( - // illegal barrierId = -1 + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + state.handleRequest( + rpcCallContext, RequestToSync( numTasks, stageId, @@ -114,15 +107,18 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with } test("abnormal test for syncing with old barrierId") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv) + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 val rpcCallContext = mockRpcCallContext() + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) // request from 3 tasks (0 until numTasks).foreach { taskId => - barrierCoordinator.receiveAndReply(mockRpcCallContext())( + state.handleRequest( + rpcCallContext, RequestToSync( numTasks, stageId, @@ -131,10 +127,10 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with barrierEpoch = 0)) } // Ensure barrierEpoch value have been changed. - val barrierState = getBarrierState(stageId, stageAttemptNumber, barrierCoordinator) - assert(barrierState.getBarrierEpoch() == 1) - assert(barrierState.isInternalStateClear()) - barrierCoordinator.receiveAndReply(rpcCallContext)( + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) + state.handleRequest( + rpcCallContext, RequestToSync( numTasks, stageId, @@ -151,20 +147,20 @@ class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext with } test("abnormal test for timeout when rpcTimeOut > barrierTimeOut") { - sc = new SparkContext("local", "test") - val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv) + val barrierCoordinator = new BarrierCoordinator( + 2, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) val numTasks = 3 val stageId = 0 val stageAttemptNumber = 0 val rpcCallContext = mockRpcCallContext() - barrierCoordinator.receiveAndReply(rpcCallContext)( - // illegal barrierId = -1 - RequestToSync( - numTasks, - stageId, - stageAttemptNumber, - taskAttemptId = 0, - barrierEpoch = 0)) + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + state.handleRequest(rpcCallContext, RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) eventually(timeout(5.seconds)) { verify(rpcCallContext, times(1)) .sendFailure(argThat(new ArgumentMatcher[SparkException] {