@@ -103,11 +103,11 @@ class KafkaDataConsumerSuite
val kafkaParams = getKafkaParams()
val key = CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context1)
val consumer1Underlying = initSingleConsumer(kafkaParams, key)

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
val context2 = new TaskContextImpl(0, 0, 0, 0, 1, 1, null, null, null)
TaskContext.setTaskContext(context2)
val consumer2Underlying = initSingleConsumer(kafkaParams, key)

@@ -123,7 +123,7 @@ class KafkaDataConsumerSuite
val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context)
setSparkEnv(
Map(
@@ -145,7 +145,7 @@ class KafkaDataConsumerSuite
val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context)
setSparkEnv(
Map(
@@ -198,7 +198,8 @@ class KafkaDataConsumerSuite

def consume(i: Int): Unit = {
val taskContext = if (Random.nextBoolean) {
new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), 1,
null, null, null)
} else {
null
}
@@ -93,15 +93,15 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context1, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
val context2 = new TaskContextImpl(0, 0, 0, 0, 1, 1, null, null, null)
val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context2, true)
consumer2.release()
@@ -126,7 +126,7 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
def consume(i: Int): Unit = {
val useCache = Random.nextBoolean
val taskContext = if (Random.nextBoolean) {
new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), 1, null, null, null)
} else {
null
}
8 changes: 3 additions & 5 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -55,10 +55,6 @@ class BarrierTaskContext private[spark] (
// with the driver side epoch.
private var barrierEpoch = 0

// Number of tasks of the current barrier stage, a barrier() call must collect enough requests
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size

private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
@@ -78,7 +74,7 @@ class BarrierTaskContext private[spark] (

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
message = RequestToSync(numPartitions, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, message, requestMethod),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
@@ -215,6 +211,8 @@ class BarrierTaskContext private[spark] (

override def partitionId(): Int = taskContext.partitionId()

override def numPartitions(): Int = taskContext.numPartitions()
Contributor: We can remove lazy val numTasks in this file and use numPartitions() directly.

Member (Author): Updated in the latest commit.

override def attemptNumber(): Int = taskContext.attemptNumber()

override def taskAttemptId(): Long = taskContext.taskAttemptId()
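For reference, a minimal barrier-mode sketch of how the new accessor behaves; the object name, RDD, and partition count are illustrative and not part of this PR. Every task of the barrier stage sees the same numPartitions value, which is also what runBarrier now sends in RequestToSync in place of the removed lazy val numTasks.

```scala
import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

// Hypothetical demo object; local[4] so all 4 barrier tasks can run concurrently.
object BarrierNumPartitionsDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("barrier-numPartitions"))
    val counts = sc.parallelize(1 to 100, numSlices = 4)
      .barrier()
      .mapPartitions { _ =>
        val ctx = BarrierTaskContext.get()
        ctx.barrier()                        // global sync across the whole stage
        Iterator.single(ctx.numPartitions()) // same value in every task: 4
      }
      .collect()
    assert(counts.toSet == Set(4))
    sc.stop()
  }
}
```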
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -67,7 +67,7 @@ object TaskContext {
* An empty task context that does not represent an actual task. This is only used in tests.
*/
private[spark] def empty(): TaskContextImpl = {
new TaskContextImpl(0, 0, 0, 0, 0,
new TaskContextImpl(0, 0, 0, 0, 0, 1,
null, new Properties, null, TaskMetrics.empty, 1)
}
}
@@ -165,6 +165,11 @@ abstract class TaskContext extends Serializable {
*/
def partitionId(): Int

/**
* Total number of partitions in the stage that this task belongs to.
*/
def numPartitions(): Int

/**
* How many times this task has been attempted. The first task attempt will be assigned
* attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
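As a quick illustration of the new public API, here is a hypothetical usage sketch mirroring what the new TaskContextSuite test further down exercises; the object name and partition count are illustrative.

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Hypothetical demo object: each task reports its own index and the stage-wide partition count.
object NumPartitionsDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("numPartitions"))
    val pairs = sc.parallelize(1 to 1000, 8)
      .mapPartitions { _ =>
        val ctx = TaskContext.get()
        Iterator.single(ctx.partitionId() -> ctx.numPartitions())
      }
      .collect()
    assert(pairs.map(_._2).toSet == Set(8)) // every task saw numPartitions == 8
    sc.stop()
  }
}
```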
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -49,6 +49,7 @@ private[spark] class TaskContextImpl(
override val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val numPartitions: Int,
override val taskMemoryManager: TaskMemoryManager,
localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
@@ -1536,8 +1536,8 @@ private[spark] class DAGScheduler(
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary,
part, stage.numPartitions, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
}

@@ -1547,7 +1547,7 @@ private[spark] class DAGScheduler(
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
taskBinary, part, stage.numPartitions, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
stage.rdd.isBarrier())
}
@@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
* partition of the given RDD. Once deserialized, the type should be
* (RDD[T], (TaskContext, Iterator[T]) => U).
* @param partition partition of the RDD this task is associated with
* @param numPartitions Total number of partitions in the stage that this task belongs to.
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
@@ -56,6 +57,7 @@ private[spark] class ResultTask[T, U](
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
numPartitions: Int,
locs: Seq[TaskLocation],
val outputId: Int,
localProperties: Properties,
@@ -64,8 +66,8 @@ private[spark] class ResultTask[T, U](
appId: Option[String] = None,
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics,
jobId, appId, appAttemptId, isBarrier)
extends Task[U](stageId, stageAttemptId, partition.index, numPartitions, localProperties,
serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
@@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
* @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
* the type should be (RDD[_], ShuffleDependency[_, _, _]).
* @param partition partition of the RDD this task is associated with
* @param numPartitions Total number of partitions in the stage that this task belongs to.
* @param locs preferred task execution locations for locality scheduling
* @param localProperties copy of thread-local properties set by the user on the driver side.
* @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
@@ -54,20 +55,21 @@ private[spark] class ShuffleMapTask(
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
numPartitions: Int,
@transient private var locs: Seq[TaskLocation],
localProperties: Properties,
serializedTaskMetrics: Array[Byte],
jobId: Option[Int] = None,
appId: Option[String] = None,
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
extends Task[MapStatus](stageId, stageAttemptId, partition.index, numPartitions, localProperties,
serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) = {
this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, 1, null, new Properties, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,6 +44,7 @@ import org.apache.spark.util._
* @param stageId id of the stage this task belongs to
* @param stageAttemptId attempt id of the stage this task belongs to
* @param partitionId index of the partition in the RDD
* @param numPartitions Total number of partitions in the stage that this task belongs to.
* @param localProperties copy of thread-local properties set by the user on the driver side.
* @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
* and sent to executor side.
@@ -59,6 +60,7 @@ private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
val numPartitions: Int,
@transient var localProperties: Properties = new Properties,
// The default value is only used in tests.
serializedTaskMetrics: Array[Byte] =
@@ -98,6 +100,7 @@ private[spark] abstract class Task[T](
partitionId,
taskAttemptId,
attemptNumber,
numPartitions,
taskMemoryManager,
localProperties,
metricsSystem,
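To summarize how the scheduler-side changes fit together, here is a deliberately simplified toy model; these classes are illustrative stand-ins, not Spark's real signatures. DAGScheduler hands stage.numPartitions to each ShuffleMapTask/ResultTask, Task stores it, and Task.run copies it into the TaskContextImpl that user code reads.

```scala
// Toy stand-ins for the real classes touched in this PR; only the numPartitions
// plumbing is modeled, everything else is omitted.
final case class ToyTaskContext(stageId: Int, partitionId: Int, numPartitions: Int)

abstract class ToyTask(val stageId: Int, val partitionId: Int, val numPartitions: Int) {
  // Mirrors Task.run(): the stage-wide count is copied into the task's context.
  def makeContext(): ToyTaskContext = ToyTaskContext(stageId, partitionId, numPartitions)
}

final class ToyShuffleMapTask(stageId: Int, partitionId: Int, numPartitions: Int)
  extends ToyTask(stageId, partitionId, numPartitions)

object ToyDAGScheduler {
  // Mirrors submitMissingTasks(): every task of the stage receives the same count.
  def makeTasks(stageId: Int, stageNumPartitions: Int): Seq[ToyTask] =
    (0 until stageNumPartitions).map(p => new ToyShuffleMapTask(stageId, p, stageNumPartitions))
}

object PlumbingDemo extends App {
  val contexts = ToyDAGScheduler.makeTasks(stageId = 1, stageNumPartitions = 4).map(_.makeContext())
  assert(contexts.map(_.numPartitions).toSet == Set(4))
}
```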
6 changes: 3 additions & 3 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -369,7 +369,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi

// first attempt -- it's successful
val context1 =
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
new TaskContextImpl(0, 0, 0, 0L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
val writer1 = manager.getWriter[Int, Int](
shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
val data1 = (1 to 10).map { x => x -> x}
@@ -378,7 +378,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val context2 =
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
new TaskContextImpl(0, 0, 0, 1L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
val writer2 = manager.getWriter[Int, Int](
shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
val data2 = (11 to 20).map { x => x -> x}
@@ -413,7 +413,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
}

val taskContext = new TaskContextImpl(
1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
1, 0, 0, 2L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
TaskContext.unset()
@@ -549,6 +549,7 @@ class ExecutorSuite extends SparkFunSuite
stageAttemptId = 0,
taskBinary = taskBinary,
partition = rdd.partitions(0),
numPartitions = 1,
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
@@ -33,6 +33,7 @@ object MemoryTestingUtils {
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
numPartitions = 1,
taskMemoryManager = taskMemoryManager,
localProperties = new Properties,
metricsSystem = env.metricsSystem)
4 changes: 2 additions & 2 deletions core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -30,7 +30,7 @@ class FakeTask(
serializedTaskMetrics: Array[Byte] =
SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
isBarrier: Boolean = false)
extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics,
extends Task[Int](stageId, 0, partitionId, 1, new Properties, serializedTaskMetrics,
isBarrier = isBarrier) {

override def runTask(context: TaskContext): Int = 0
@@ -96,7 +96,7 @@ object FakeTask {
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new ShuffleMapTask(stageId, stageAttemptId, null, new Partition {
override def index: Int = i
}, prefLocs(i), new Properties,
}, 1, prefLocs(i), new Properties,
SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
}
new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null,
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
extends Task[Array[Byte]](stageId, 0, 0) {
extends Task[Array[Byte]](stageId, 0, 0, 1) {

override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, 1, null, Option.empty)
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, 1, null, Option.empty)
@@ -187,6 +187,28 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
}

test("TaskContext.get.numPartitions getter") {
sc = new SparkContext("local[1,2]", "test")

for (numPartitions <- 1 to 10) {
val numPartitionsFromContext = sc.parallelize(1 to 1000, numPartitions)
.mapPartitions { _ =>
Seq(TaskContext.get.numPartitions()).iterator
}.collect()
assert(numPartitionsFromContext.toSet === Set(numPartitions),
s"numPartitions = $numPartitions")
}

for (numPartitions <- 1 to 10) {
val numPartitionsFromContext = sc.parallelize(1 to 1000, 2).repartition(numPartitions)
.mapPartitions { _ =>
Seq(TaskContext.get.numPartitions()).iterator
}.collect()
assert(numPartitionsFromContext.toSet === Set(numPartitions),
s"numPartitions = $numPartitions")
}
}

test("accumulators are updated on exception failures") {
// This means use 1 core and 4 max task failures
sc = new SparkContext("local[1,4]", "test")
@@ -218,8 +240,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
val taskMetrics = TaskMetrics.empty
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0, 0L, 0,
val task = new Task[Int](0, 0, 0, 1) {
context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
@@ -241,8 +263,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
val taskMetrics = TaskMetrics.registered
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0, 0L, 0,
val task = new Task[Int](0, 0, 0, 1) {
context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
@@ -2025,11 +2025,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
new WorkerOffer("executor1", "host1", 1))
val task1 = new ShuffleMapTask(1, 0, null, new Partition {
override def index: Int = 0
}, Seq(TaskLocation("host0", "executor0")), new Properties, null)
}, 1, Seq(TaskLocation("host0", "executor0")), new Properties, null)

val task2 = new ShuffleMapTask(1, 0, null, new Partition {
override def index: Int = 1
}, Seq(TaskLocation("host1", "executor1")), new Properties, null)
}, 1, Seq(TaskLocation("host1", "executor1")), new Properties, null)

val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
