diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 5e546c694e8d9..2e984db5e2f07 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -32,7 +32,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListener * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is * from. */ -private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { +private[spark] case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" } @@ -84,13 +84,13 @@ private[spark] class BarrierCoordinator( /** * Provide the current state of a barrier() call. A state is created when a new stage attempt - * sends out a barrier() call, and recycled on stage completed. + * sends out a barrier() call, and recycled on stage completed. Visible for testing. * * @param barrierId Identifier of the barrier stage that make a barrier() call. * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall * collect `numTasks` requests to succeed. */ - private class ContextBarrierState( + private[spark] class ContextBarrierState( val barrierId: ContextBarrierId, val numTasks: Int) { @@ -141,7 +141,7 @@ private[spark] class BarrierCoordinator( logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") if (epoch != barrierEpoch) { requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + - s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + s"barrier epoch $epoch has already finished. Maybe task $taskId is not " + "properly killed.")) } else { // If this is the first sync message received for a barrier() call, start timer to ensure @@ -187,6 +187,12 @@ private[spark] class BarrierCoordinator( requesters.clear() cancelTimerTask() } + + // Check for internal state clear, visible for test only. + private[spark] def isInternalStateClear(): Boolean = requesters.isEmpty && timerTask == null + + // Get currently barrier epoch, visible for test only. + private[spark] def getBarrierEpoch(): Int = barrierEpoch } // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. @@ -215,7 +221,7 @@ private[spark] class BarrierCoordinator( } } -private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private sealed trait BarrierCoordinatorMessage extends Serializable /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is diff --git a/core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala new file mode 100644 index 0000000000000..af912c84c0dfc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ContextBarrierStateSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.mockito.ArgumentMatcher +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.concurrent.Eventually + +import org.apache.spark._ +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} + +class ContextBarrierStateSuite extends SparkFunSuite with LocalSparkContext with Eventually { + + private def mockRpcCallContext() = { + val rpcAddress = mock(classOf[RpcAddress]) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.senderAddress).thenReturn(rpcAddress) + rpcCallContext + } + + test("normal test for single task") { + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) + val stageId = 0 + val stageAttemptNumber = 0 + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks = 1) + state.handleRequest( + mockRpcCallContext(), + RequestToSync( + numTasks = 1, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + // Ensure barrierEpoch value have been changed. + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) + } + + test("normal test for multi tasks") { + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + // request from 3 tasks + (0 until numTasks).foreach { taskId => + state.handleRequest(mockRpcCallContext(), RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = taskId, + barrierEpoch = 0)) + } + // Ensure barrierEpoch value have been changed. + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) + } + + test("abnormal test for syncing with illegal barrierId") { + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcCallContext = mockRpcCallContext() + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + state.handleRequest( + rpcCallContext, + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = -1)) + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The request to sync of Stage 0 (Attempt 0) with barrier epoch -1 has already" + + " finished. Maybe task 0 is not properly killed." + } + })) + } + + test("abnormal test for syncing with old barrierId") { + val barrierCoordinator = new BarrierCoordinator( + 5, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcCallContext = mockRpcCallContext() + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + // request from 3 tasks + (0 until numTasks).foreach { taskId => + state.handleRequest( + rpcCallContext, + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = taskId, + barrierEpoch = 0)) + } + // Ensure barrierEpoch value have been changed. + assert(state.getBarrierEpoch() == 1) + assert(state.isInternalStateClear()) + state.handleRequest( + rpcCallContext, + RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The request to sync of Stage 0 (Attempt 0) with barrier epoch 0 has already" + + " finished. Maybe task 0 is not properly killed." + }})) + } + + test("abnormal test for timeout when rpcTimeOut > barrierTimeOut") { + val barrierCoordinator = new BarrierCoordinator( + 2, mock(classOf[LiveListenerBus]), mock(classOf[RpcEnv])) + val numTasks = 3 + val stageId = 0 + val stageAttemptNumber = 0 + val rpcCallContext = mockRpcCallContext() + val state = new barrierCoordinator.ContextBarrierState( + ContextBarrierId(stageId, stageAttemptNumber), numTasks) + state.handleRequest(rpcCallContext, RequestToSync( + numTasks, + stageId, + stageAttemptNumber, + taskAttemptId = 0, + barrierEpoch = 0)) + eventually(timeout(5.seconds)) { + verify(rpcCallContext, times(1)) + .sendFailure(argThat(new ArgumentMatcher[SparkException] { + override def matches(e: Any): Boolean = { + e.asInstanceOf[SparkException].getMessage == + "The coordinator didn't get all barrier sync requests for barrier epoch" + + " 0 from Stage 0 (Attempt 0) within 2 second(s)." + } + })) + } + } +}