diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301eab..8f4909ea2939e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -92,9 +92,9 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { - val taskDesc = TaskDescription.decode(data.value) + val (taskDesc, serializedTask) = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc) + executor.launchTask(this, taskDesc, serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 790c1ae942474..39c592557571c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -152,8 +152,11 @@ private[spark] class Executor( private[executor] def numRunningTasks: Int = runningTasks.size() - def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { - val tr = new TaskRunner(context, taskDescription) + def launchTask( + context: ExecutorBackend, + taskDescription: TaskDescription, + serializedTask: ByteBuffer): Unit = { + val tr = new TaskRunner(context, taskDescription, serializedTask) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) } @@ -210,7 +213,8 @@ private[spark] class Executor( class TaskRunner( execBackend: ExecutorBackend, - private val taskDescription: TaskDescription) + private val taskDescription: TaskDescription, + private val serializedTask: ByteBuffer) extends Runnable { val taskId = taskDescription.taskId @@ -289,8 +293,8 @@ private[spark] class Executor( Executor.taskDeserializationProps.set(taskDescription.properties) updateDependencies(taskDescription.addedFiles, taskDescription.addedJars) - task = ser.deserialize[Task[Any]]( - taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) + task = Utils.deserialize(serializedTask, + Thread.currentThread.getContextClassLoader).asInstanceOf[Task[Any]] task.localProperties = taskDescription.properties task.setTaskMemoryManager(taskMemoryManager) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e404..29bc5bc1af19c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -24,7 +24,10 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, Map} +import scala.util.control.NonFatal +import org.apache.spark.TaskNotSerializableException +import org.apache.spark.internal.Logging import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -53,8 +56,26 @@ private[spark] class TaskDescription( val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, - val serializedTask: ByteBuffer) { - + // Task object corresponding to the TaskDescription. 
This is only defined on the driver; on + // the executor, the Task object is handled separately from the TaskDescription so that it can + // be deserialized after the TaskDescription is deserialized. + @transient private val task: Task[_] = null) extends Logging { + + /** + * Serializes the task for this TaskDescription and returns the serialized task. + * + * This method should only be used on the driver (to serialize a task to send to an executor). + */ + def serializeTask(): ByteBuffer = { + try { + ByteBuffer.wrap(Utils.serialize(task)) + } catch { + case NonFatal(e) => + val msg = s"Failed to serialize task $taskId." + logError(msg, e) + throw new TaskNotSerializableException(e) + } + } override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) } @@ -67,6 +88,7 @@ private[spark] object TaskDescription { } } + @throws[TaskNotSerializableException] def encode(taskDescription: TaskDescription): ByteBuffer = { val bytesOut = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(bytesOut) @@ -93,8 +115,8 @@ private[spark] object TaskDescription { dataOut.write(bytes) } - // Write the task. The task is already serialized, so write it directly to the byte buffer. - Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut) + // Serialize and write the task. + Utils.writeByteBuffer(taskDescription.serializeTask(), bytesOut) dataOut.close() bytesOut.close() @@ -110,7 +132,7 @@ private[spark] object TaskDescription { map } - def decode(byteBuffer: ByteBuffer): TaskDescription = { + def decode(byteBuffer: ByteBuffer): (TaskDescription, ByteBuffer) = { val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer)) val taskId = dataIn.readLong() val attemptNumber = dataIn.readInt() @@ -138,7 +160,8 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + (new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, + properties), + serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bfbcfa1aa386f..4b4c23b745cf1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.util.control.NonFatal import scala.util.Random import org.apache.spark._ @@ -34,7 +35,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} +import org.apache.spark.util.{AccumulatorV2, RpcUtils, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
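For reference, the new TaskDescription contract introduced above separates the task metadata from the task bytes: encode() now serializes the Task on the driver, and decode() on the executor returns the metadata together with a still-serialized buffer. The sketch below illustrates that round trip; it is only a sketch, it has to live inside Spark's own org.apache.spark.scheduler package (TaskDescription, Task and Utils are private[spark]), and the Task[_] argument is assumed to be any already-constructed concrete task.

    package org.apache.spark.scheduler

    import java.nio.ByteBuffer
    import java.util.Properties

    import scala.collection.mutable.HashMap

    import org.apache.spark.util.Utils

    object TaskDescriptionRoundTripSketch {
      // `task` is any concrete Task[_]; the TaskDescription now holds it only on the driver.
      def roundTrip(task: Task[_]): Unit = {
        val desc = new TaskDescription(
          taskId = 0L,
          attemptNumber = 0,
          executorId = "exec-0",
          name = "task 0.0 in stage 0.0",
          index = 0,
          addedFiles = new HashMap[String, Long](),
          addedJars = new HashMap[String, Long](),
          properties = new Properties(),
          task = task)
        // Driver side: serialization now happens inside encode(), so a non-serializable task
        // surfaces here as TaskNotSerializableException instead of inside resourceOffer().
        val wire: ByteBuffer = TaskDescription.encode(desc)
        // Executor side: decode() returns the metadata plus a slice that still holds the
        // serialized task; the Task itself is deserialized later, on the task thread.
        val (meta, taskBytes) = TaskDescription.decode(wire)
        val deserialized = Utils.deserialize[Task[_]](taskBytes, getClass.getClassLoader)
        assert(meta.taskId == desc.taskId && deserialized != null)
      }
    }

Because the serialized bytes never live inside the TaskDescription object itself, serialization can be deferred until encode() is called, which is what allows it to move off the scheduling path below.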
@@ -93,6 +94,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // CPUs to request per task val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) + val MAX_RPC_MESSAGE_SIZE = RpcUtils.maxMessageSizeBytes(conf) + // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] @@ -277,23 +280,15 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val execId = shuffledOffers(i).executorId val host = shuffledOffers(i).host if (availableCpus(i) >= CPUS_PER_TASK) { - try { - for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetManager(tid) = taskSet - taskIdToExecutorId(tid) = execId - executorIdToRunningTaskIds(execId).add(tid) - availableCpus(i) -= CPUS_PER_TASK - assert(availableCpus(i) >= 0) - launchedTask = true - } - } catch { - case e: TaskNotSerializableException => - logError(s"Resource offer failed, task set ${taskSet.name} was not serializable") - // Do not offer resources for this task, but don't throw an error to allow other - // task sets to be submitted. - return launchedTask + for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetManager(tid) = taskSet + taskIdToExecutorId(tid) = execId + executorIdToRunningTaskIds(execId).add(tid) + availableCpus(i) -= CPUS_PER_TASK + assert(availableCpus(i) >= 0) + launchedTask = true } } } @@ -374,6 +369,77 @@ private[spark] class TaskSchedulerImpl private[scheduler]( return tasks } + private[scheduler] def makeOffersAndSerializeTasks( + offers: IndexedSeq[WorkerOffer]): Seq[Seq[(TaskDescription, ByteBuffer)]] = { + resourceOffers(offers).map(_.map { task => + val serializedTask = prepareSerializedTask(task, MAX_RPC_MESSAGE_SIZE) + if (serializedTask == null) { + statusUpdate(task.taskId, TaskState.KILLED, null.asInstanceOf[ByteBuffer]) + } + (task, serializedTask) + }.filter(_._2 != null)) + } + + private[scheduler] def prepareSerializedTask( + task: TaskDescription, + maxRpcMessageSize: Long): ByteBuffer = { + var serializedTask: ByteBuffer = null + try { + if (!getTaskSetManager(task.taskId).exists(_.isZombie)) { + serializedTask = TaskDescription.encode(task) + } + } catch { + case NonFatal(e) => + abortTaskSetManager(task.taskId, + s"Failed to serialize task ${task.taskId}, not attempting to retry it.", Some(e)) + } + + if (serializedTask != null && serializedTask.limit >= maxRpcMessageSize) { + val msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + + "spark.rpc.message.maxSize or using broadcast variables for large values." + abortTaskSetManager(task.taskId, + msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)) + serializedTask = null + } else if (serializedTask != null) { + maybeEmitTaskSizeWarning(serializedTask, task.taskId) + } + serializedTask + } + + private def maybeEmitTaskSizeWarning( + serializedTask: ByteBuffer, + taskId: Long): Unit = { + if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) { + getTaskSetManager(taskId).filterNot(_.emittedTaskSizeWarning). + foreach { taskSetMgr => + taskSetMgr.emittedTaskSizeWarning = true + val stageId = taskSetMgr.taskSet.stageId + logWarning(s"Stage $stageId contains a task of very large size " + + s"(${serializedTask.limit / 1024} KB). 
The maximum recommended task size is " + + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") + } + } + } + + // abort TaskSetManager without exception + private def abortTaskSetManager( + taskId: Long, + msg: => String, + exception: Option[Throwable] = None): Unit = { + getTaskSetManager(taskId).foreach { taskSetMgr => + try { + taskSetMgr.abort(msg, exception) + } catch { + case e: Exception => logError("Exception while aborting taskset", e) + } + } + } + + private def getTaskSetManager(taskId: Long): Option[TaskSetManager] = synchronized { + taskIdToTaskSetManager.get(taskId) + } + /** * Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow * overriding in tests, so it can be deterministic. @@ -570,7 +636,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( /** * Cleans up the TaskScheduler's state for tracking the given task. */ - private def cleanupTaskState(tid: Long): Unit = { + private[scheduler] def cleanupTaskState(tid: Long): Unit = { taskIdToTaskSetManager.remove(tid) taskIdToExecutorId.remove(tid).foreach { executorId => executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 19ebaf817e24e..cc36b6f2995b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -65,7 +65,6 @@ private[spark] class TaskSetManager( // Serializer for closures and tasks. val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks val numTasks = tasks.length @@ -413,7 +412,6 @@ private[spark] class TaskSetManager( * @param host the host Id of the offered resource * @param maxLocality the maximum locality we want to schedule the tasks at */ - @throws[TaskNotSerializableException] def resourceOffer( execId: String, host: String, @@ -454,25 +452,7 @@ private[spark] class TaskSetManager( currentLocalityIndex = getLocalityIndex(taskLocality) lastLaunchTime = curTime } - // Serialize and return the task - val serializedTask: ByteBuffer = try { - ser.serialize(task) - } catch { - // If the task cannot be serialized, then there's no point to re-attempt the task, - // as it will always fail. So just abort the whole task-set. - case NonFatal(e) => - val msg = s"Failed to serialize task $taskId, not attempting to retry it." - logError(msg, e) - abort(s"$msg Exception during serialization: $e") - throw new TaskNotSerializableException(e) - } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && - !emittedTaskSizeWarning) { - emittedTaskSizeWarning = true - logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). 
The maximum recommended task size is " + - s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") - } + addRunningTask(taskId) // We used to log the time it takes to serialize the task, but task size is already @@ -480,7 +460,7 @@ private[spark] class TaskSetManager( // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + - s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") + s"partition ${task.partitionId}, $taskLocality)") sched.dagScheduler.taskStarted(task, info) new TaskDescription( @@ -492,7 +472,7 @@ private[spark] class TaskSetManager( sched.sc.addedFiles, sched.sc.addedJars, task.localProperties, - serializedTask) + task) } } else { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 2898cd7d17ca0..988d4b53d47cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.ExecutorLossReason +import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.util.SerializableBuffer private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -37,6 +37,8 @@ private[spark] object CoarseGrainedClusterMessages { case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage + case class SerializeTask(task: TaskDescription) extends CoarseGrainedClusterMessage + // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 94abe30bb12f2..ba2a6fe259f75 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,13 +17,14 @@ package org.apache.spark.scheduler.cluster -import java.util.concurrent.TimeUnit +import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future -import scala.concurrent.duration.Duration import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} import org.apache.spark.internal.Logging @@ -67,7 +68,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by // `CoarseGrainedSchedulerBackend.this`. 
- private val executorDataMap = new HashMap[String, ExecutorData] + private val executorDataMap = new ConcurrentHashMap[String, ExecutorData]().asScala // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") @@ -92,6 +93,37 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + class SerializeTaskEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging { + + override def receive: PartialFunction[Any, Unit] = { + case SerializeTask(task: TaskDescription) => + serializeTask(task) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopDriver => + context.reply(true) + stop() + } + + private def serializeTask(task: TaskDescription): Unit = { + val serializedTask = scheduler.prepareSerializedTask(task, maxRpcMessageSize) + val executorData = executorDataMap.get(task.executorId).orNull + if (executorData == null) { + driverEndpoint.send(StatusUpdate(task.executorId, task.taskId, TaskState.KILLED, + null.asInstanceOf[ByteBuffer])) + return + } + if (serializedTask != null) { + logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + s"${executorData.executorHost}, serializedTask: ${serializedTask.limit} bytes.") + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) + } else { + driverEndpoint.send(StatusUpdate(task.executorId, task.taskId, TaskState.FAILED, + ByteBuffer.allocate(0))) + } + } + } class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -120,7 +152,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (TaskState.isFinished(state)) { executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.freeCores += scheduler.CPUS_PER_TASK + executorInfo.incrementFreeCores(scheduler.CPUS_PER_TASK) makeOffers(executorId) case None => // Ignoring the update since we don't know about the executor. @@ -256,31 +288,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { - for (task <- tasks.flatten) { - val serializedTask = TaskDescription.encode(task) - if (serializedTask.limit >= maxRpcMessageSize) { - scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => - try { - var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + - "spark.rpc.message.maxSize (%d bytes). Consider increasing " + - "spark.rpc.message.maxSize or using broadcast variables for large values."
- msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) - taskSetMgr.abort(msg) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } - else { - val executorData = executorDataMap(task.executorId) - executorData.freeCores -= scheduler.CPUS_PER_TASK - - logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + - s"${executorData.executorHost}.") - - executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) - } - } + tasks.foreach(_.foreach { task => + val executorData = executorDataMap(task.executorId) + executorData.decrementFreeCores(scheduler.CPUS_PER_TASK) + serializeTaskEndpoint.send(SerializeTask(task)) + }) } // Remove a disconnected slave from the cluster @@ -344,6 +356,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } var driverEndpoint: RpcEndpointRef = null + var serializeTaskEndpoint: RpcEndpointRef = null protected def minRegisteredRatio: Double = _minRegisteredRatio @@ -357,6 +370,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // TODO (prashant) send conf instead of properties driverEndpoint = createDriverEndpointRef(properties) + serializeTaskEndpoint = createSerializeTaskEndpointRef() } protected def createDriverEndpointRef( @@ -364,6 +378,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) } + protected def createSerializeTaskEndpointRef(): RpcEndpointRef = { + rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.SERIALIZE_TASK_ENDPOINT_NAME, + new SerializeTaskEndpoint(rpcEnv)) + } + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new DriverEndpoint(rpcEnv, properties) } @@ -374,6 +393,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo("Shutting down all executors") driverEndpoint.askSync[Boolean](StopExecutors) } + + if (serializeTaskEndpoint != null) { + serializeTaskEndpoint.askSync[Boolean](StopDriver) + } } catch { case e: Exception => throw new SparkException("Error asking standalone scheduler to shut down executors", e) @@ -621,6 +644,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } -private[spark] object CoarseGrainedSchedulerBackend { +private[spark] object CoarseGrainedSchedulerBackend extends Logging { val ENDPOINT_NAME = "CoarseGrainedScheduler" + val SERIALIZE_TASK_ENDPOINT_NAME = "SerializeTaskCoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index b25a4bfb501fb..ef343ccf7089f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster +import io.netty.util.internal.chmv8.LongAdderV8 + import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** @@ -25,14 +27,27 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} * @param executorEndpoint The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on - * @param freeCores The current number of cores available for work on the executor + * @param initFreeCores The 
current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorEndpoint: RpcEndpointRef, - val executorAddress: RpcAddress, - override val executorHost: String, - var freeCores: Int, - override val totalCores: Int, - override val logUrlMap: Map[String, String] -) extends ExecutorInfo(executorHost, totalCores, logUrlMap) + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, + override val executorHost: String, + private val initFreeCores: Int, + override val totalCores: Int, + override val logUrlMap: Map[String, String]) + extends ExecutorInfo(executorHost, totalCores, logUrlMap) { + private val freeCoresUpdater = new LongAdderV8() + freeCoresUpdater.add(initFreeCores) + + def incrementFreeCores(increment: Int): Unit = { + freeCoresUpdater.add(increment) + } + + def decrementFreeCores(decrement: Int): Unit = { + freeCoresUpdater.add(-decrement) + } + + def freeCores: Int = freeCoresUpdater.intValue() +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd4608..43fec3c444d11 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -21,6 +21,8 @@ import java.io.File import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable.HashSet + import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} @@ -29,6 +31,7 @@ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.util.RpcUtils private case class ReviveOffers() @@ -82,9 +85,10 @@ private[spark] class LocalEndpoint( def reviveOffers() { val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - for (task <- scheduler.resourceOffers(offers).flatten) { + for ((_, buffer) <- scheduler.makeOffersAndSerializeTasks(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, task) + val (taskDesc, serializedTask) = TaskDescription.decode(buffer) + executor.launchTask(executorBackend, taskDesc, serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1af34e3da231f..afce553c8a368 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -151,7 +151,12 @@ private[spark] object Utils extends Logging { /** Deserialize an object using Java serialization and the given ClassLoader */ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) + deserialize(ByteBuffer.wrap(bytes), loader) + } + + /** Deserialize an object using Java serialization and the given ClassLoader */ + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val bis = new ByteBufferInputStream(bytes) val ois = new ObjectInputStream(bis) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { // scalastyle:off classforname diff 
--git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8150fff2d018d..79a18f1c6cb1b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -53,7 +53,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) - val taskDescription = createFakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription() // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -103,7 +103,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug try { executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread - executor.launchTask(mockExecutorBackend, taskDescription) + executor.launchTask(mockExecutorBackend, taskDescription, serializedTask) if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { fail("executor did not send first status update in time") @@ -152,9 +152,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug ) val serTask = serializer.serialize(task) - val taskDescription = createFakeTaskDescription(serTask) + val taskDescription = createFakeTaskDescription() - val failReason = runTaskAndGetFailReason(taskDescription) + val failReason = runTaskAndGetFailReason(taskDescription, serTask) assert(failReason.isInstanceOf[FetchFailed]) } @@ -184,10 +184,10 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug ) val serTask = serializer.serialize(task) - val taskDescription = createFakeTaskDescription(serTask) + val taskDescription = createFakeTaskDescription() val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) + runTaskGetFailReasonAndExceptionHandler(taskDescription, serTask) // make sure the task failure just looks like a OOM, not a fetch failure assert(failReason.isInstanceOf[ExceptionFailure]) val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) @@ -201,9 +201,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) - val taskDescription = createFakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription() - val failReason = runTaskAndGetFailReason(taskDescription) + val failReason = runTaskAndGetFailReason(taskDescription, serializedTask) failReason match { case ef: ExceptionFailure => assert(ef.exception.isDefined) @@ -228,7 +228,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug mockEnv } - private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + private def createFakeTaskDescription(): TaskDescription = { new TaskDescription( taskId = 0, attemptNumber = 0, @@ -237,16 +237,17 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug index = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), - properties = new Properties, - serializedTask) + properties 
= new Properties) } - private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + private def runTaskAndGetFailReason(taskDescription: TaskDescription, + serializedTask: ByteBuffer): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription, serializedTask)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + serializedTask: ByteBuffer): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null @@ -254,7 +255,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread - executor.launchTask(mockBackend, taskDescription) + executor.launchTask(mockBackend, taskDescription, serializedTask) eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328e..7c515d6dbe1e1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,11 +17,37 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import java.io.{IOException, NotSerializableException, ObjectInputStream, ObjectOutputStream} + +import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +private[spark] class NotSerializablePartitionRDD( + sc: SparkContext, + numPartitions: Int) extends RDD[(Int, Int)](sc, Nil) with Serializable { + + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + throw new NotSerializableException() + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} + }).toArray + override def getPreferredLocations(partition: Partition): Seq[String] = Nil + + override def toString: String = "DAGSchedulerSuiteRDD " + id +} + +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than max RPC message size") { val conf = new SparkConf conf.set("spark.rpc.message.maxSize", "1") @@ -38,4 +64,19 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("Scheduler aborts stages that have unserializable partition") { + val conf = new SparkConf() + .setMaster("local-cluster[2, 1, 1024]") + .setAppName("test") + .set("spark.dynamicAllocation.testing", "true") 
+ sc = new SparkContext(conf) + val myRDD = new NotSerializablePartitionRDD(sc, 2) + val e = intercept[SparkException] { + myRDD.count() + } + assert(e.getMessage.contains("Failed to serialize task")) + assertResult(10) { + sc.parallelize(1 to 10).count() + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8eaf9dfcf49b1..ae98ee7cdbc12 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{IOException, NotSerializableException, ObjectInputStream, ObjectOutputStream} import java.util.Properties import java.util.concurrent.atomic.AtomicBoolean @@ -517,6 +518,32 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("unserializable partitioner") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new Partitioner { + override def numPartitions = 1 + + override def getPartition(key: Any) = 1 + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + throw new NotSerializableException() + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} + }) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(failure.getMessage.startsWith( + "Job aborted due to stage failure: Task not serializable")) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty() + } + test("trivial job failure") { submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..be3ff2ed38cf8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -56,6 +56,6 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, new Properties()) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca8..44ae24f067a7b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -65,11 +65,15 @@ class TaskDescriptionSuite extends SparkFunSuite { originalFiles, originalJars, originalProperties, - taskBuffer - ) + // Pass in null for the task, because we override the serialize method below anyway (which + // is the only time task is used). 
+ task = null + ) { + override def serializeTask() = taskBuffer + } val serializedTaskDescription = TaskDescription.encode(originalTaskDescription) - val decodedTaskDescription = TaskDescription.decode(serializedTaskDescription) + val (decodedTaskDescription, serializedTask) = TaskDescription.decode(serializedTaskDescription) // Make sure that all of the fields in the decoded task description match the original. assert(decodedTaskDescription.taskId === originalTaskDescription.taskId) @@ -80,6 +84,6 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) - assert(decodedTaskDescription.serializedTask.equals(taskBuffer)) + assert(serializedTask.equals(taskBuffer)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ae0bcd9b8860..f7e2268af3917 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -178,29 +178,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!failedTaskSet) } - test("Scheduler does not crash when tasks are not serializable") { - val taskCpus = 2 - val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) - val numFreeCores = 1 - val taskSet = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), - new WorkerOffer("executor1", "host1", numFreeCores)) - taskScheduler.submitTasks(taskSet) - var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten - assert(0 === taskDescriptions.length) - assert(failedTaskSet) - assert(failedTaskSetReason.contains("Failed to serialize task")) - - // Now check that we can still submit tasks - // Even if one of the task sets has not-serializable tasks, the other task set should - // still be processed without error - taskScheduler.submitTasks(FakeTask.createTaskSet(1)) - taskScheduler.submitTasks(taskSet) - taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten - assert(taskDescriptions.map(_.executorId) === Seq("executor0")) - } - test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") { val taskScheduler = setupScheduler() val attempt1 = FakeTask.createTaskSet(1, 0) @@ -904,4 +881,26 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskDescs.size === 1) assert(taskDescs.head.executorId === "exec2") } + + test("serialization task errors do not affect each other") { + import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + val conf = new SparkConf().setMaster("local").setAppName("test") + sc = new SparkContext(conf) + + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(mock[CoarseGrainedSchedulerBackend]) + taskScheduler.setDAGScheduler(mock[DAGScheduler]) + val taskSet1 = FakeTask.createTaskSet(1) + val taskSet2 = FakeTask.createTaskSet(1) + taskSet1.tasks(0) = new NotSerializableFakeTask(1, 0) + val taskIdToTaskSetManager = taskScheduler.taskIdToTaskSetManager + taskScheduler.submitTasks(taskSet1) + taskScheduler.submitTasks(taskSet2) + val offers 
= Array(WorkerOffer("1", "localhost", 2), WorkerOffer("2", "localhost", 2)) + taskScheduler.makeOffersAndSerializeTasks(offers) + assert(taskIdToTaskSetManager.values.exists(_.taskSet == taskSet2)) + taskIdToTaskSetManager.values.filter(_.taskSet == taskSet2).foreach { taskSet => + assert(taskSet.isZombie === false) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2c2cda9f318eb..260fc338c7ba8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -612,34 +612,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(!manager.emittedTaskSizeWarning) } - test("emit warning when serialized task is large") { - sc = new SparkContext("local", "test") - sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - - val taskSet = new TaskSet(Array(new LargeTask(0)), 0, 0, 0, null) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) - - assert(!manager.emittedTaskSizeWarning) - - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - - assert(manager.emittedTaskSizeWarning) - } - - test("Not serializable exception thrown if the task cannot be serialized") { - sc = new SparkContext("local", "test") - sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - - val taskSet = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) - - intercept[TaskNotSerializableException] { - manager.resourceOffer("exec1", "host1", ANY) - } - assert(manager.isZombie) - } - test("abort the job if total size of results is too large") { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b252539782580..29711a17097ca 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -85,12 +85,12 @@ private[spark] class MesosExecutorBackend } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { - val taskDescription = TaskDescription.decode(taskInfo.getData.asReadOnlyByteBuffer()) + val (taskDesc, serializedTask) = TaskDescription.decode(taskInfo.getData.asReadOnlyByteBuffer()) if (executor == null) { logError("Received launchTask but executor was null") } else { SparkHadoopUtil.get.runAsSparkUser { () => - executor.launchTask(this, taskDescription) + executor.launchTask(this, taskDesc, serializedTask) } } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 215271302ec51..9ca454fe89ea8 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.scheduler.cluster.mesos import java.io.File +import java.nio.ByteBuffer import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.util.Utils +import org.apache.spark.util.{RpcUtils, Utils} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -300,22 +301,20 @@ private[spark] class MesosFineGrainedSchedulerBackend( val slavesIdsOfAcceptedOffers = HashSet[String]() // Call into the TaskSchedulerImpl - val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) - acceptedOffers - .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - val (mesosTask, remainingResources) = createMesosTask( - taskDesc, - slaveIdToResources(slaveId), - slaveId) - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(mesosTask) - slaveIdToResources(slaveId) = remainingResources - } - } + val acceptedOffers = scheduler.makeOffersAndSerializeTasks(workerOffers).flatten + for ((task, serializedTask) <- acceptedOffers) { + val slaveId = task.executorId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(task.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( + task, + serializedTask, + slaveIdToResources(slaveId), + slaveId) + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
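The fine-grained Mesos path above now consumes pre-serialized tasks the same way LocalSchedulerBackend does. A minimal sketch of the pattern a backend is expected to follow is below; the OfferSketch class and its ship callback are hypothetical, and the scheduler is assumed to be an initialized TaskSchedulerImpl (the class must sit in the org.apache.spark.scheduler package, since makeOffersAndSerializeTasks is private[scheduler]).

    package org.apache.spark.scheduler

    import java.nio.ByteBuffer

    // Hypothetical helper showing how a backend consumes makeOffersAndSerializeTasks;
    // `ship` stands in for the backend-specific send (Mesos TaskInfo, local launchTask, RPC).
    private[spark] class OfferSketch(scheduler: TaskSchedulerImpl) {
      def reviveOffers(freeCores: Int, ship: (TaskDescription, ByteBuffer) => Unit): Unit = {
        val offers = IndexedSeq(WorkerOffer("exec-0", "host-0", freeCores))
        // Tasks whose serialization failed (or whose task set went zombie) are already
        // reported via statusUpdate and filtered out inside makeOffersAndSerializeTasks,
        // so only launchable (TaskDescription, bytes) pairs reach the backend.
        for ((desc, bytes) <- scheduler.makeOffersAndSerializeTasks(offers).flatten) {
          ship(desc, bytes)
        }
      }
    }

Core accounting (CPUS_PER_TASK) stays with the caller, as in the LocalSchedulerBackend change earlier in this patch.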
@@ -341,6 +340,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ def createMesosTask( task: TaskDescription, + serializedTask: ByteBuffer, resources: JList[Resource], slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -358,7 +358,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setExecutor(executorInfo) .setName(task.name) .addAllResources(cpuResources.asJava) - .setData(ByteString.copyFrom(TaskDescription.encode(task))) + .setData(ByteString.copyFrom(serializedTask)) .build() (taskInfo, finalResources.asJava) } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 98033bec6dd68..51b854508e63f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -554,6 +554,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite // override to avoid race condition with the driver thread on `mesosDriver` override def startScheduler(newDriver: SchedulerDriver): Unit = {} + override protected def createSerializeTaskEndpointRef(): RpcEndpointRef = null + override def stopExecutors(): Unit = { stopCalled = true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 4ee85b91830a9..d6ebc3b64da2d 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -255,9 +255,10 @@ class MesosFineGrainedSchedulerBackendSuite index = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], - properties = new Properties(), - ByteBuffer.wrap(new Array[Byte](0))) - when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + properties = new Properties()) + val serializedTask = TaskDescription.encode(taskDesc) + when(taskScheduler.makeOffersAndSerializeTasks(expectedWorkerOffers)). + thenReturn(Seq(Seq((taskDesc, serializedTask)))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) @@ -293,7 +294,8 @@ class MesosFineGrainedSchedulerBackendSuite mesosOffers2.add(createOffer(1, minMem, minCpu)) reset(taskScheduler) reset(driver) - when(taskScheduler.resourceOffers(any(classOf[IndexedSeq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.makeOffersAndSerializeTasks(any(classOf[IndexedSeq[WorkerOffer]]))). 
+ thenReturn(Seq(Seq())) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) @@ -363,9 +365,10 @@ class MesosFineGrainedSchedulerBackendSuite index = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], - properties = new Properties(), - ByteBuffer.wrap(new Array[Byte](0))) - when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + properties = new Properties()) + val serializedTask = TaskDescription.encode(taskDesc) + when(taskScheduler.makeOffersAndSerializeTasks(expectedWorkerOffers)). + thenReturn(Seq(Seq((taskDesc, serializedTask)))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]])
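End to end, the executor half of this change is the same for CoarseGrainedExecutorBackend and MesosExecutorBackend: decode the metadata, then hand the still-serialized buffer to the executor. A minimal sketch of that receive path, assuming backend and executor are already-constructed instances (the helper object itself is hypothetical):

    package org.apache.spark.executor

    import java.nio.ByteBuffer

    import org.apache.spark.scheduler.TaskDescription

    private[spark] object LaunchSketch {
      def onLaunchTask(backend: ExecutorBackend, executor: Executor, data: ByteBuffer): Unit = {
        // Only the metadata (files, jars, properties) is deserialized here; the Task object
        // stays as raw bytes until TaskRunner.run() deserializes it on the task thread.
        val (taskDesc, serializedTask) = TaskDescription.decode(data)
        executor.launchTask(backend, taskDesc, serializedTask)
      }
    }

Keeping the Task bytes opaque at this point keeps task deserialization (and any user-class loading it triggers) off the backend's RPC receive path.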