24 changes: 22 additions & 2 deletions core/src/main/java/org/apache/spark/TaskContext.java
@@ -62,7 +62,7 @@ static void unset() {
*/
public abstract boolean isInterrupted();

/** @deprecated: use isRunningLocally() */
/** @deprecated use {@link #isRunningLocally()} */
@Deprecated
public abstract boolean runningLocally();

@@ -87,19 +87,39 @@ static void unset() {
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
* @deprecated: use addTaskCompletionListener
* @deprecated use {@link #addTaskCompletionListener(scala.Function1)}
*
* @param f Callback function.
*/
@Deprecated
public abstract void addOnCompleteCallback(final Function0<Unit> f);

/**
The ID of the stage that this task belongs to.
*/
public abstract int stageId();

/**
* The ID of the RDD partition that is computed by this task.
*/
public abstract int partitionId();

/**
* How many times this task has been attempted. The first task attempt will be assigned
* attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
*/
public abstract int attemptNumber();

/** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */
@Deprecated
public abstract long attemptId();

/**
* An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
* will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
*/
public abstract long taskAttemptId();

/** ::DeveloperApi:: */
@DeveloperApi
public abstract TaskMetrics taskMetrics();
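The section above exposes several distinct identifiers. A minimal sketch (not part of this diff) of user code reading them, assuming a TaskContext instance is available inside a task, for example via an RDD closure:

    // Sketch only: reads the identifiers exposed by the updated TaskContext API.
    // attemptNumber counts retries of this one task (0, 1, 2, ...), while
    // taskAttemptId is unique across all task attempts in the SparkContext.
    def describeContext(ctx: org.apache.spark.TaskContext): String =
      s"stage=${ctx.stageId} partition=${ctx.partitionId} " +
        s"attemptNumber=${ctx.attemptNumber} taskAttemptId=${ctx.taskAttemptId}"

The deprecated attemptId() still compiles and now returns the same value as taskAttemptId().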
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce

import scala.collection.mutable.ArrayBuffer

private[spark] class TaskContextImpl(val stageId: Int,
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
val attemptId: Long,
override val taskAttemptId: Long,
override val attemptNumber: Int,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {

// For backwards-compatibility; this method is now deprecated as of 1.3.0.
override def attemptId: Long = taskAttemptId

// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

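For callers that construct the context directly (tests and local execution in this diff), the widened constructor reads naturally with named arguments. A sketch, which would have to live in Spark's own source tree since TaskContextImpl is private[spark]:

    // Sketch only: stage 3, partition 1, on its first retry.
    val ctx = new TaskContextImpl(
      stageId = 3,
      partitionId = 1,
      taskAttemptId = 12345L,  // unique within the SparkContext
      attemptNumber = 1,       // 0 for the first attempt, 1 for the first retry
      runningLocally = false)
    assert(ctx.attemptId == ctx.taskAttemptId)  // the deprecated name forwards to the new one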
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -71,7 +71,8 @@ private[spark] class CoarseGrainedExecutorBackend(
val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}

case KillTask(taskId, _, interruptThread) =>
17 changes: 13 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -108,8 +108,13 @@ private[spark] class Executor(
startDriverHeartbeater()

def launchTask(
context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, taskName, serializedTask)
context: ExecutorBackend,
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -134,7 +139,11 @@ private[spark] class Executor(
private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum

class TaskRunner(
execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
execBackend: ExecutorBackend,
val taskId: Long,
val attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer)
extends Runnable {

@volatile private var killed = false
@@ -180,7 +189,7 @@ private[spark] class Executor(

// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()

// If the task has been killed, let's fail it.
core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData}
import org.apache.spark.util.{SignalLogger, Utils}

private[spark] class MesosExecutorBackend
@@ -77,11 +78,13 @@ private[spark] class MesosExecutorBackend

override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
SparkHadoopUtil.get.runAsSparkUser { () =>
executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber,
taskInfo.getName, taskData.serializedTask)
}
}
}
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging {

val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
val tempOutputPath =
new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)

if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ ctx.attemptId + " and final output path does not exist")
+ ctx.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
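With the rename, the temporary checkpoint file now embeds the per-task attempt number, so retries of the same partition still write to distinct temporary paths. An illustrative fragment (the values are made up):

    // Sketch only: temp names produced for two attempts of the same partition.
    val finalOutputName = "part-00007"
    Seq(0, 1).map(n => s".$finalOutputName-attempt-$n")
    // -> List(".part-00007-attempt-0", ".part-00007-attempt-1")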
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -229,7 +229,7 @@ class HadoopRDD[K, V](
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

// Register an on-task-completion callback to close the input stream.
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -978,12 +978,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])

val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
context.attemptNumber)
val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
@@ -1062,11 +1059,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt

val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)

writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
try {
var recordsWritten = 0L
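Hadoop's old-API writer still wants a 32-bit task attempt ID, so the 64-bit taskAttemptId is folded into an Int before the writer.setup call above. A self-contained sketch of that truncation:

    // Sketch only: fold a Long task attempt id into Hadoop's Int-sized field.
    // Collisions would require on the order of Int.MaxValue task attempts.
    def to32BitAttemptId(taskAttemptId: Long): Int =
      (taskAttemptId % Int.MaxValue).toInt

    assert(to32BitAttemptId(7L) == 7)
    assert(to32BitAttemptId(Int.MaxValue.toLong + 7L) == 7)

In the new-API path earlier in the same file, the per-task attemptNumber is already an Int, so it is handed to newTaskAttemptID directly.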
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -635,8 +635,8 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,8 +44,16 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

final def run(attemptId: Long): T = {
context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false)
/**
* Called by Executor to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
* @return the result of the task
*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
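Executor.TaskRunner (shown earlier) supplies both values at the call site: task.run(taskAttemptId = taskId, attemptNumber = attemptNumber). A hypothetical wrapper, only to illustrate the shape of the new signature (Task itself is private[spark]):

    // Sketch only: the executor-side call after this change. The scheduler's
    // taskId doubles as the SparkContext-unique task attempt id.
    def runDeserializedTask(task: org.apache.spark.scheduler.Task[Any],
                            taskId: Long,
                            attemptNumber: Int): Any =
      task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)

On a retry, the scheduler issues a fresh taskId, so taskAttemptId changes while partitionId stays fixed and attemptNumber increments.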
core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer
*/
private[spark] class TaskDescription(
val taskId: Long,
val attemptNumber: Int,
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -487,7 +487,8 @@ private[spark] class TaskSetManager(
taskName, taskId, host, taskLocality, serializedTask.limit))

sched.dagScheduler.taskStarted(task, info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
taskName, index, serializedTask))
}
case _ =>
}
core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -22,7 +22,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.collection.mutable.{HashMap, HashSet}

import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
@@ -296,7 +296,7 @@ private[spark] class MesosSchedulerBackend(
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
.setData(ByteString.copyFrom(task.serializedTask))
.setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
}

core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala (new file)
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler.cluster.mesos

import java.nio.ByteBuffer

import org.apache.mesos.protobuf.ByteString

/**
* Wrapper for serializing the data sent when launching Mesos tasks.
*/
private[spark] case class MesosTaskLaunchData(
serializedTask: ByteBuffer,
attemptNumber: Int) {

def toByteString: ByteString = {
val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
dataBuffer.putInt(attemptNumber)
dataBuffer.put(serializedTask)
ByteString.copyFrom(dataBuffer)
}
}

private[spark] object MesosTaskLaunchData {
def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
val byteBuffer = byteString.asReadOnlyByteBuffer()
val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
val serializedTask = byteBuffer.slice() // subsequence starting at the current position
MesosTaskLaunchData(serializedTask, attemptNumber)
}
}
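The framing above is simply a 4-byte attempt number followed by the serialized task bytes; fromByteString reads the Int and then slices off the remainder. Since Mesos only carries a single opaque data field per task, the attempt number has to ride along with the serialized task in that one blob. A standalone sketch of the same framing using plain ByteBuffers (no Mesos protobuf dependency), to make the round trip easy to try out:

    import java.nio.ByteBuffer

    // Sketch only: mirrors the [4-byte attemptNumber | task bytes] layout used
    // by MesosTaskLaunchData, without the ByteString wrapper.
    object TaskFramingSketch {
      def encode(serializedTask: ByteBuffer, attemptNumber: Int): ByteBuffer = {
        val buf = ByteBuffer.allocate(4 + serializedTask.remaining())
        buf.putInt(attemptNumber)
        buf.put(serializedTask.duplicate())  // duplicate() leaves the caller's position untouched
        buf.flip()                           // flip so the frame can be read from the start
        buf
      }

      def decode(frame: ByteBuffer): (Int, ByteBuffer) = {
        val buf = frame.duplicate()
        val attemptNumber = buf.getInt       // advances the position by 4 bytes
        (attemptNumber, buf.slice())         // the rest is the serialized task
      }

      def main(args: Array[String]): Unit = {
        val task = ByteBuffer.wrap("fake task bytes".getBytes("UTF-8"))
        val (n, payload) = decode(encode(task, attemptNumber = 2))
        println(s"attemptNumber=$n payloadBytes=${payload.remaining}")  // attemptNumber=2 payloadBytes=15
      }
    }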
core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -76,7 +76,8 @@ private[spark] class LocalActor(
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
for (task <- scheduler.resourceOffers(offers).flatten) {
freeCores -= scheduler.CPUS_PER_TASK
executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask)
executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
task.name, task.serializedTask)
}
}
}
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -820,7 +820,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}

8 changes: 4 additions & 4 deletions core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
val context = new TaskContextImpl(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
val context = new TaskContextImpl(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,15 +94,15 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
val context = new TaskContextImpl(0, 0, 0, true)
val context = new TaskContextImpl(0, 0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
}

test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
val context = new TaskContextImpl(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
val tContext = new TaskContextImpl(0, 0, 0)
val tContext = new TaskContextImpl(0, 0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")