@@ -92,9 +92,9 @@ private[spark] class CoarseGrainedExecutorBackend(
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = TaskDescription.decode(data.value)
val (taskDesc, serializedTask) = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc)
executor.launchTask(this, taskDesc, serializedTask)
}

case KillTask(taskId, _, interruptThread) =>
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -152,8 +152,11 @@ private[spark] class Executor(

private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
def launchTask(
context: ExecutorBackend,
taskDescription: TaskDescription,
serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskDescription, serializedTask)
runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
}
@@ -210,7 +213,8 @@ private[spark] class Executor(

class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
private val taskDescription: TaskDescription,
private val serializedTask: ByteBuffer)
extends Runnable {

val taskId = taskDescription.taskId
@@ -289,8 +293,8 @@ private[spark] class Executor(
Executor.taskDeserializationProps.set(taskDescription.properties)

updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task = Utils.deserialize(serializedTask,
Thread.currentThread.getContextClassLoader).asInstanceOf[Task[Any]]
task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)

@@ -24,7 +24,10 @@ import java.util.Properties

import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, Map}
import scala.util.control.NonFatal

import org.apache.spark.TaskNotSerializableException
import org.apache.spark.internal.Logging
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}

/**
@@ -53,8 +56,26 @@ private[spark] class TaskDescription(
val addedFiles: Map[String, Long],
val addedJars: Map[String, Long],
val properties: Properties,
val serializedTask: ByteBuffer) {

// Task object corresponding to the TaskDescription. This is only defined on the driver; on
// the executor, the Task object is handled separately from the TaskDescription so that it can
// be deserialized after the TaskDescription is deserialized.
@transient private val task: Task[_] = null) extends Logging {

/**
* Serializes the task for this TaskDescription and returns the serialized task.
*
* This method should only be used on the driver (to serialize a task to send to an executor).
*/
def serializeTask(): ByteBuffer = {
try {
ByteBuffer.wrap(Utils.serialize(task))
} catch {
case NonFatal(e) =>
val msg = s"Failed to serialize task $taskId."
logError(msg, e)
throw new TaskNotSerializableException(e)
}
}
override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
}

@@ -67,6 +88,7 @@ private[spark] object TaskDescription {
}
}

@throws[TaskNotSerializableException]
def encode(taskDescription: TaskDescription): ByteBuffer = {
val bytesOut = new ByteBufferOutputStream(4096)
val dataOut = new DataOutputStream(bytesOut)
@@ -93,8 +115,8 @@ private[spark] object TaskDescription {
dataOut.write(bytes)
}

// Write the task. The task is already serialized, so write it directly to the byte buffer.
Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut)
// Serialize and write the task.
Utils.writeByteBuffer(taskDescription.serializeTask(), bytesOut)

dataOut.close()
bytesOut.close()
@@ -110,7 +132,7 @@ private[spark] object TaskDescription {
map
}

def decode(byteBuffer: ByteBuffer): TaskDescription = {
def decode(byteBuffer: ByteBuffer): (TaskDescription, ByteBuffer) = {
val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
val taskId = dataIn.readLong()
val attemptNumber = dataIn.readInt()
@@ -138,7 +160,8 @@ private[spark] object TaskDescription {
// Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
val serializedTask = byteBuffer.slice()

new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
properties, serializedTask)
(new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
properties),
serializedTask)
}
}
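Reviewer note: with this change `TaskDescription.encode` becomes the single place where the task object is serialized, and `decode` hands the still-serialized task back as a separate buffer. Below is a minimal round-trip sketch of the new API, not part of the diff: it assumes spark-internal access (the class is `private[spark]`), an already-constructed `Task[_]` named `someTask`, and that the leading constructor parameter names match the existing ones.

```scala
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.HashMap

import org.apache.spark.scheduler.{Task, TaskDescription}

// Sketch only: `someTask` is a concrete Task[_] built elsewhere (e.g. in a test).
def roundTrip(someTask: Task[_]): Unit = {
  val desc = new TaskDescription(
    taskId = 1L,
    attemptNumber = 0,
    executorId = "exec-1",
    name = "task 1.0 in stage 0.0",
    index = 0,
    addedFiles = new HashMap[String, Long](),
    addedJars = new HashMap[String, Long](),
    properties = new Properties(),
    task = someTask)

  // Driver side: encode() now calls serializeTask(), so a non-serializable task
  // surfaces here as a TaskNotSerializableException rather than inside resourceOffer.
  val bytes: ByteBuffer = TaskDescription.encode(desc)

  // Executor side: decode() returns the metadata plus the still-serialized task,
  // which TaskRunner deserializes later with the task's context class loader.
  val (decoded, serializedTask) = TaskDescription.decode(bytes)
  assert(decoded.taskId == desc.taskId)
  assert(serializedTask.remaining() > 0)
}
```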
104 changes: 85 additions & 19 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicLong

import scala.collection.Set
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.control.NonFatal
import scala.util.Random

import org.apache.spark._
@@ -34,7 +35,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
import org.apache.spark.util.{AccumulatorV2, RpcUtils, ThreadUtils, Utils}

/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -93,6 +94,8 @@ private[spark] class TaskSchedulerImpl private[scheduler](
// CPUs to request per task
val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)

val MAX_RPC_MESSAGE_SIZE = RpcUtils.maxMessageSizeBytes(conf)

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
@@ -277,23 +280,15 @@ private[spark] class TaskSchedulerImpl private[scheduler](
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
if (availableCpus(i) >= CPUS_PER_TASK) {
try {
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
} catch {
case e: TaskNotSerializableException =>
logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
// Do not offer resources for this task, but don't throw an error to allow other
// task sets to be submitted.
return launchedTask
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
}
}
@@ -374,6 +369,77 @@ private[spark] class TaskSchedulerImpl private[scheduler](
return tasks
}

private[scheduler] def makeOffersAndSerializeTasks(
offers: IndexedSeq[WorkerOffer]): Seq[Seq[(TaskDescription, ByteBuffer)]] = {
resourceOffers(offers).map(_.map { task =>
val serializedTask = prepareSerializedTask(task, MAX_RPC_MESSAGE_SIZE)
if (serializedTask == null) {
statusUpdate(task.taskId, TaskState.KILLED, null.asInstanceOf[ByteBuffer])
}
(task, serializedTask)
}.filter(_._2 != null))
}

private[scheduler] def prepareSerializedTask(
task: TaskDescription,
maxRpcMessageSize: Long): ByteBuffer = {
var serializedTask: ByteBuffer = null
try {
if (!getTaskSetManager(task.taskId).exists(_.isZombie)) {
serializedTask = TaskDescription.encode(task)
}
} catch {
case NonFatal(e) =>
abortTaskSetManager(task.taskId,
s"Failed to serialize task ${task.taskId}, not attempting to retry it.", Some(e))
}

if (serializedTask != null && serializedTask.limit >= maxRpcMessageSize) {
val msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.rpc.message.maxSize (%d bytes). Consider increasing " +
"spark.rpc.message.maxSize or using broadcast variables for large values."
abortTaskSetManager(task.taskId,
msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize))
serializedTask = null
} else if (serializedTask != null) {
maybeEmitTaskSizeWarning(serializedTask, task.taskId)
}
serializedTask
}

private def maybeEmitTaskSizeWarning(
serializedTask: ByteBuffer,
taskId: Long): Unit = {
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) {
getTaskSetManager(taskId).filterNot(_.emittedTaskSizeWarning).
foreach { taskSetMgr =>
taskSetMgr.emittedTaskSizeWarning = true
val stageId = taskSetMgr.taskSet.stageId
logWarning(s"Stage $stageId contains a task of very large size " +
s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
}
}
}

// Abort the TaskSetManager for the given task without letting exceptions from abort() propagate.
private def abortTaskSetManager(
taskId: Long,
msg: => String,
exception: Option[Throwable] = None): Unit = {
getTaskSetManager(taskId).foreach { taskSetMgr =>
try {
taskSetMgr.abort(msg, exception)
} catch {
case e: Exception => logError("Exception while aborting taskset", e)
}
}
}

private def getTaskSetManager(taskId: Long): Option[TaskSetManager] = synchronized {
taskIdToTaskSetManager.get(taskId)
}

/**
* Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow
* overriding in tests, so it can be deterministic.
@@ -570,7 +636,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
/**
* Cleans up the TaskScheduler's state for tracking the given task.
*/
private def cleanupTaskState(tid: Long): Unit = {
private[scheduler] def cleanupTaskState(tid: Long): Unit = {
taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid).foreach { executorId =>
executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) }
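Reviewer note: the scheduling path now hands back fully encoded task buffers, so the backend side only ships bytes. The `CoarseGrainedSchedulerBackend` changes are not part of these hunks; the sketch below is a hypothetical illustration of how a backend sitting in the `cluster` package might consume `makeOffersAndSerializeTasks`, with `executorDataMap` and `ExecutorData` assumed from the backend's existing bookkeeping.

```scala
import java.nio.ByteBuffer

import org.apache.spark.scheduler.{TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
import org.apache.spark.scheduler.cluster.ExecutorData
import org.apache.spark.util.SerializableBuffer

// Hypothetical helper, not part of this diff: `executorDataMap` is assumed to be the
// backend's existing executorId -> ExecutorData map.
def launchTasks(
    scheduler: TaskSchedulerImpl,
    offers: IndexedSeq[WorkerOffer],
    executorDataMap: collection.Map[String, ExecutorData]): Unit = {
  val tasks: Seq[Seq[(TaskDescription, ByteBuffer)]] =
    scheduler.makeOffersAndSerializeTasks(offers)
  for (taskSetTasks <- tasks; (taskDesc, serializedTask) <- taskSetTasks) {
    // Tasks that failed to serialize or exceeded spark.rpc.message.maxSize were already
    // filtered out (and their task sets aborted) by prepareSerializedTask.
    val executorData = executorDataMap(taskDesc.executorId)
    executorData.freeCores -= scheduler.CPUS_PER_TASK
    executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
  }
}
```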
@@ -65,7 +65,6 @@ private[spark] class TaskSetManager(

// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()

val tasks = taskSet.tasks
val numTasks = tasks.length
@@ -413,7 +412,6 @@ private[spark] class TaskSetManager(
* @param host the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
@throws[TaskNotSerializableException]
def resourceOffer(
execId: String,
host: String,
@@ -454,33 +452,15 @@ private[spark] class TaskSetManager(
currentLocalityIndex = getLocalityIndex(taskLocality)
lastLaunchTime = curTime
}
// Serialize and return the task
val serializedTask: ByteBuffer = try {
ser.serialize(task)
} catch {
// If the task cannot be serialized, then there's no point to re-attempt the task,
// as it will always fail. So just abort the whole task-set.
case NonFatal(e) =>
val msg = s"Failed to serialize task $taskId, not attempting to retry it."
logError(msg, e)
abort(s"$msg Exception during serialization: $e")
throw new TaskNotSerializableException(e)
}
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
!emittedTaskSizeWarning) {
emittedTaskSizeWarning = true
logWarning(s"Stage ${task.stageId} contains a task of very large size " +
s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
}

addRunningTask(taskId)

// We used to log the time it takes to serialize the task, but task size is already
// a good proxy to task serialization time.
// val timeTaken = clock.getTime() - startTime
val taskName = s"task ${info.id} in stage ${taskSet.id}"
logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " +
s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)")
s"partition ${task.partitionId}, $taskLocality)")

sched.dagScheduler.taskStarted(task, info)
new TaskDescription(
@@ -492,7 +472,7 @@ private[spark] class TaskSetManager(
sched.sc.addedFiles,
sched.sc.addedJars,
task.localProperties,
serializedTask)
task)
}
} else {
None
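Reviewer note: `resourceOffer` no longer serializes anything and can no longer throw `TaskNotSerializableException`; it simply wraps the live `Task` in the returned `TaskDescription`. A minimal sketch of the new contract, assuming an already-constructed `TaskSetManager` with pending tasks:

```scala
import org.apache.spark.scheduler.{TaskDescription, TaskLocality, TaskSetManager}

// Sketch only: `manager` is an assumed, fully initialized TaskSetManager.
def offerOnce(manager: TaskSetManager): Unit = {
  val maybeDesc: Option[TaskDescription] =
    manager.resourceOffer("exec-1", "host-1", TaskLocality.ANY)

  maybeDesc.foreach { desc =>
    // Serialization is deferred to TaskDescription.encode (driven by
    // TaskSchedulerImpl.prepareSerializedTask), so a non-serializable task now aborts
    // the task set there instead of failing here.
    val bytes = TaskDescription.encode(desc)
    assert(bytes.remaining() > 0)
  }
}
```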
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer

import org.apache.spark.TaskState.TaskState
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorLossReason
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.util.SerializableBuffer

private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
@@ -37,6 +37,8 @@ private[spark] object CoarseGrainedClusterMessages {

case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage

case class SerializeTask(task: TaskDescription) extends CoarseGrainedClusterMessage

// Driver to executors
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
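Reviewer note: this hunk only adds the `SerializeTask` message; the endpoint that handles it is not shown in these hunks. Purely as a hypothetical illustration, a driver-side endpoint's `receiveAndReply` could serialize on request like this:

```scala
import org.apache.spark.rpc.RpcCallContext
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.SerializeTask

// Hypothetical handler body, not part of this diff.
def handleSerializeTask(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case SerializeTask(task) =>
    try {
      // TaskDescription.encode may throw TaskNotSerializableException.
      context.reply(TaskDescription.encode(task))
    } catch {
      case e: Exception => context.sendFailure(e)
    }
}
```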
