diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumerSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumerSuite.scala
index c607c4fc81b71..30e8e348f74d2 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumerSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumerSuite.scala
@@ -103,11 +103,11 @@ class KafkaDataConsumerSuite
     val kafkaParams = getKafkaParams()
     val key = CacheKey(groupId, topicPartition)
 
-    val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+    val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
     TaskContext.setTaskContext(context1)
     val consumer1Underlying = initSingleConsumer(kafkaParams, key)
 
-    val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
+    val context2 = new TaskContextImpl(0, 0, 0, 0, 1, 1, null, null, null)
    TaskContext.setTaskContext(context2)
     val consumer2Underlying = initSingleConsumer(kafkaParams, key)
 
@@ -123,7 +123,7 @@ class KafkaDataConsumerSuite
     val kafkaParams = getKafkaParams()
     val key = new CacheKey(groupId, topicPartition)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+    val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
     TaskContext.setTaskContext(context)
     setSparkEnv(
       Map(
@@ -145,7 +145,7 @@ class KafkaDataConsumerSuite
     val kafkaParams = getKafkaParams()
     val key = new CacheKey(groupId, topicPartition)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+    val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
     TaskContext.setTaskContext(context)
     setSparkEnv(
       Map(
@@ -198,7 +198,8 @@ class KafkaDataConsumerSuite
 
     def consume(i: Int): Unit = {
       val taskContext = if (Random.nextBoolean) {
-        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), 1,
+          null, null, null)
       } else {
         null
       }
diff --git a/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
index 82913cf416a5f..9c461e73875b8 100644
--- a/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
+++ b/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -93,7 +93,7 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
     val kafkaParams = getKafkaParams()
     val key = new CacheKey(groupId, topicPartition)
 
-    val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+    val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null)
     val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
       topicPartition, kafkaParams, context1, true)
     consumer1.release()
@@ -101,7 +101,7 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
     assert(KafkaDataConsumer.cache.size() == 1)
     assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))
 
-    val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
+    val context2 = new TaskContextImpl(0, 0, 0, 0, 1, 1, null, null, null)
     val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
       topicPartition, kafkaParams, context2, true)
     consumer2.release()
@@ -126,7 +126,7 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
     def consume(i: Int): Unit = {
       val useCache = Random.nextBoolean
       val taskContext = if (Random.nextBoolean) {
-        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), 1, null, null, null)
       } else {
         null
       }
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index aa63e617b723a..ecc0c891ea161 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -55,10 +55,6 @@ class BarrierTaskContext private[spark] (
   // with the driver side epoch.
   private var barrierEpoch = 0
 
-  // Number of tasks of the current barrier stage, a barrier() call must collect enough requests
-  // from different tasks within the same barrier stage attempt to succeed.
-  private lazy val numTasks = getTaskInfos().size
-
   private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
     logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
       s"the global sync, current barrier epoch is $barrierEpoch.")
@@ -78,7 +74,7 @@ class BarrierTaskContext private[spark] (
 
     try {
       val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
-        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
+        message = RequestToSync(numPartitions, stageId, stageAttemptNumber, taskAttemptId,
           barrierEpoch, partitionId, message, requestMethod),
         // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
         // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
@@ -215,6 +211,8 @@ class BarrierTaskContext private[spark] (
 
   override def partitionId(): Int = taskContext.partitionId()
 
+  override def numPartitions(): Int = taskContext.numPartitions()
+
   override def attemptNumber(): Int = taskContext.attemptNumber()
 
   override def taskAttemptId(): Long = taskContext.taskAttemptId()
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index ed781be299b71..3d5be09d0ebe3 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -67,7 +67,7 @@ object TaskContext {
    * An empty task context that does not represent an actual task. This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, 0,
+    new TaskContextImpl(0, 0, 0, 0, 0, 1,
       null, new Properties, null, TaskMetrics.empty, 1)
   }
 }
@@ -165,6 +165,11 @@ abstract class TaskContext extends Serializable {
    */
  def partitionId(): Int
 
+  /**
+   * Total number of partitions in the stage that this task belongs to.
+   */
+  def numPartitions(): Int
+
   /**
    * How many times this task has been attempted. The first task attempt will be assigned
    * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
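
Usage sketch (illustrative, not part of the diff): the numPartitions() method added to TaskContext above can be read directly from user code running on executors. A minimal example, assuming an active SparkContext named sc; the 1-to-1000 range and the 4-partition count are only illustrative:

    val stageSizes = sc.parallelize(1 to 1000, 4).mapPartitions { _ =>
      val ctx = org.apache.spark.TaskContext.get()
      // numPartitions() is the total number of partitions in this task's stage, here 4.
      Iterator.single((ctx.partitionId(), ctx.numPartitions()))
    }.collect()
    // Every pair's second element is 4; the TaskContextSuite test added later in this
    // diff exercises the same pattern.
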
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index cb7f4304d07cb..075d79b9d55a7 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -49,6 +49,7 @@ private[spark] class TaskContextImpl(
     override val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
+    override val numPartitions: Int,
     override val taskMemoryManager: TaskMemoryManager,
     localProperties: Properties,
     @transient private val metricsSystem: MetricsSystem,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ffaabba71e8cc..0e8c2f699557b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1536,8 +1536,8 @@ private[spark] class DAGScheduler(
           val locs = taskIdToLocations(id)
           val part = partitions(id)
           stage.pendingPartitions += id
-          new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
-            taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
+          new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary,
+            part, stage.numPartitions, locs, properties, serializedTaskMetrics, Option(jobId),
             Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
         }
 
@@ -1547,7 +1547,7 @@ private[spark] class DAGScheduler(
           val part = partitions(p)
           val locs = taskIdToLocations(id)
           new ResultTask(stage.id, stage.latestInfo.attemptNumber,
-            taskBinary, part, locs, id, properties, serializedTaskMetrics,
+            taskBinary, part, stage.numPartitions, locs, id, properties, serializedTaskMetrics,
             Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
             stage.rdd.isBarrier())
         }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 15f2161fac39d..cc3677fc4d4ae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
 *    partition of the given RDD. Once deserialized, the type should be
 *    (RDD[T], (TaskContext, Iterator[T]) => U).
 * @param partition partition of the RDD this task is associated with
+ * @param numPartitions Total number of partitions in the stage that this task belongs to.
 * @param locs preferred task execution locations for locality scheduling
 * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
 *                 input RDD's partitions).
@@ -56,6 +57,7 @@ private[spark] class ResultTask[T, U](
     stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
+    numPartitions: Int,
     locs: Seq[TaskLocation],
     val outputId: Int,
     localProperties: Properties,
@@ -64,8 +66,8 @@ private[spark] class ResultTask[T, U](
     appId: Option[String] = None,
     appAttemptId: Option[String] = None,
     isBarrier: Boolean = false)
-  extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics,
-    jobId, appId, appAttemptId, isBarrier)
+  extends Task[U](stageId, stageAttemptId, partition.index, numPartitions, localProperties,
+    serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
   with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 89db3a86f4ce8..b068709410842 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
 * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
 *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
 * @param partition partition of the RDD this task is associated with
+ * @param numPartitions Total number of partitions in the stage that this task belongs to.
 * @param locs preferred task execution locations for locality scheduling
 * @param localProperties copy of thread-local properties set by the user on the driver side.
 * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
@@ -54,6 +55,7 @@ private[spark] class ShuffleMapTask(
     stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
+    numPartitions: Int,
     @transient private var locs: Seq[TaskLocation],
     localProperties: Properties,
     serializedTaskMetrics: Array[Byte],
@@ -61,13 +63,13 @@ private[spark] class ShuffleMapTask(
     appId: Option[String] = None,
     appAttemptId: Option[String] = None,
     isBarrier: Boolean = false)
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, numPartitions, localProperties,
     serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
   with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) = {
-    this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, 1, null, new Properties, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 3ef8361efe8e1..8f11520693add 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,6 +44,7 @@ import org.apache.spark.util._
 * @param stageId id of the stage this task belongs to
 * @param stageAttemptId attempt id of the stage this task belongs to
 * @param partitionId index of the number in the RDD
+ * @param numPartitions Total number of partitions in the stage that this task belongs to.
 * @param localProperties copy of thread-local properties set by the user on the driver side.
 * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
 *                              and sent to executor side.
@@ -59,6 +60,7 @@ private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
+    val numPartitions: Int,
     @transient var localProperties: Properties = new Properties,
     // The default value is only used in tests.
     serializedTaskMetrics: Array[Byte] =
@@ -98,6 +100,7 @@ private[spark] abstract class Task[T](
       partitionId,
       taskAttemptId,
       attemptNumber,
+      numPartitions,
       taskMemoryManager,
       localProperties,
       metricsSystem,
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index c1a964c336109..6022d49a0b504 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -369,7 +369,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
 
     // first attempt -- its successful
     val context1 =
-      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
+      new TaskContextImpl(0, 0, 0, 0L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
     val writer1 = manager.getWriter[Int, Int](
       shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
     val data1 = (1 to 10).map { x => x -> x}
@@ -378,7 +378,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val context2 =
-      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
+      new TaskContextImpl(0, 0, 0, 1L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
     val writer2 = manager.getWriter[Int, Int](
       shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
     val data2 = (11 to 20).map { x => x -> x}
@@ -413,7 +413,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
     }
 
     val taskContext = new TaskContextImpl(
-      1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
+      1, 0, 0, 2L, 0, 1, taskMemoryManager, new Properties, metricsSystem)
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
     TaskContext.unset()
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 943f4e115a596..2205fd47af912 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -549,6 +549,7 @@ class ExecutorSuite extends SparkFunSuite
       stageAttemptId = 0,
       taskBinary = taskBinary,
       partition = rdd.partitions(0),
+      numPartitions = 1,
       locs = Seq(),
       outputId = 0,
       localProperties = new Properties(),
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index dcf89e4f75acf..67fbf115d80b6 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -33,6 +33,7 @@ object MemoryTestingUtils {
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,
+      numPartitions = 1,
       taskMemoryManager = taskMemoryManager,
       localProperties = new Properties,
       metricsSystem = env.metricsSystem)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 9ec088aaddddd..fdd89378927e4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -30,7 +30,7 @@ class FakeTask(
     serializedTaskMetrics: Array[Byte] =
       SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
     isBarrier: Boolean = false)
-  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics,
+  extends Task[Int](stageId, 0, partitionId, 1, new Properties, serializedTaskMetrics,
     isBarrier = isBarrier) {
 
   override def runTask(context: TaskContext): Int = 0
@@ -96,7 +96,7 @@ object FakeTask {
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new ShuffleMapTask(stageId, stageAttemptId, null, new Partition {
         override def index: Int = i
-      }, prefLocs(i), new Properties,
+      }, 1, prefLocs(i), new Properties,
        SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
     }
     new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 255be6f46b06b..2631ab2a92a74 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
 * A Task implementation that fails to serialize.
 */
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0, 1) {
 
   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 693841d843f0b..0f6fe6ca35896 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
     val task = new ResultTask[String, String](
-      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
+      0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
       task.run(0, 0, null, 1, null, Option.empty)
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
     val task = new ResultTask[String, String](
-      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
+      0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
       task.run(0, 0, null, 1, null, Option.empty)
@@ -187,6 +187,28 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
   }
 
+  test("TaskContext.get.numPartitions getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    for (numPartitions <- 1 to 10) {
+      val numPartitionsFromContext = sc.parallelize(1 to 1000, numPartitions)
+        .mapPartitions { _ =>
+          Seq(TaskContext.get.numPartitions()).iterator
+        }.collect()
+      assert(numPartitionsFromContext.toSet === Set(numPartitions),
+        s"numPartitions = $numPartitions")
+    }
+
+    for (numPartitions <- 1 to 10) {
+      val numPartitionsFromContext = sc.parallelize(1 to 1000, 2).repartition(numPartitions)
+        .mapPartitions { _ =>
+          Seq(TaskContext.get.numPartitions()).iterator
+        }.collect()
+      assert(numPartitionsFromContext.toSet === Set(numPartitions),
+        s"numPartitions = $numPartitions")
+    }
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -218,8 +240,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
-    val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0, 0L, 0,
+    val task = new Task[Int](0, 0, 0, 1) {
+      context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -241,8 +263,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.registered
-    val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0, 0L, 0,
+    val task = new Task[Int](0, 0, 0, 1) {
+      context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 85ea4f582e37c..71eda0063fec9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -2025,11 +2025,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       new WorkerOffer("executor1", "host1", 1))
     val task1 = new ShuffleMapTask(1, 0, null, new Partition {
       override def index: Int = 0
-    }, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+    }, 1, Seq(TaskLocation("host0", "executor0")), new Properties, null)
 
     val task2 = new ShuffleMapTask(1, 0, null, new Partition {
       override def index: Int = 1
-    }, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+    }, 1, Seq(TaskLocation("host1", "executor1")), new Properties, null)
 
     val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 360a14b031139..f21daa1aea6c5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -181,7 +181,7 @@ class FakeTaskScheduler(
 /**
  * A Task implementation that results in a large serialized task.
 */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, 1) {
 
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KIB * 1024)
   val random = new Random(0)
@@ -853,7 +853,7 @@ class TaskSetManagerSuite
 
     val singleTask = new ShuffleMapTask(0, 0, null, new Partition {
       override def index: Int = 0
-    }, Seq(TaskLocation("host1", "execA")), new Properties, null)
+    }, 1, Seq(TaskLocation("host1", "execA")), new Properties, null)
     val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null,
       ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 8ffc6798526b4..bb4de5e116ad8 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -64,7 +64,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
     try {
       TaskContext.setTaskContext(
         new TaskContextImpl(0, 0, 0, taskAttemptId, 0,
-          null, new Properties, null, TaskMetrics.empty, 1))
+          1, null, new Properties, null, TaskMetrics.empty, 1))
       block
     } finally {
       TaskContext.unset()
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index cc148d9e247f6..56ddd3223fe47 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -94,6 +94,9 @@ object MimaExcludes {
     // [SPARK-36173][CORE] Support getting CPU number in TaskContext
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.cpus"),
 
+    // [SPARK-38679][CORE] Expose the number of partitions in a stage to TaskContext
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.numPartitions"),
+
     // [SPARK-35896] Include more granular metrics for stateful operators in StreamingQueryProgress
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"),
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
index 06fc2022c01ad..fc96a0f4ac986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
@@ -103,7 +103,7 @@ class AggregatingAccumulatorSuite
     checkResult(acc_driver.value, InternalRow(null, null, 0), acc_driver.schema, false)
 
     def inPartition(id: Int)(f: => Unit): Unit = {
-      val ctx = new TaskContextImpl(0, 0, 1, 0, 0, null, new Properties, null)
+      val ctx = new TaskContextImpl(0, 0, 1, 0, 0, 1, null, new Properties, null)
       TaskContext.setTaskContext(ctx)
       try {
         f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 97e5c1148c244..5359793610037 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -78,6 +78,7 @@ class UnsafeFixedWidthAggregationMapSuite
       stageId = 0,
       stageAttemptNumber = 0,
       partitionId = 0,
+      numPartitions = 1,
       taskAttemptId = Random.nextInt(10000),
       attemptNumber = 0,
       taskMemoryManager = taskMemoryManager,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index f630cd8322c61..b3370b6733d92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -121,6 +121,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
       partitionId = 0,
       taskAttemptId = 98456,
       attemptNumber = 0,
+      numPartitions = 1,
       taskMemoryManager = taskMemMgr,
       localProperties = new Properties,
       metricsSystem = null))
@@ -216,7 +217,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap`
     // which has duplicated keys and the number of entries exceeds its capacity.
     try {
-      val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null)
+      val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties(), null)
       TaskContext.setTaskContext(context)
       new UnsafeKVExternalSorter(
         schema,
@@ -239,7 +240,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     val schema = new StructType().add("i", IntegerType)
 
     try {
-      val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null)
+      val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties(), null)
       TaskContext.setTaskContext(context)
       val expectedSpillSize = map.getTotalMemoryConsumption
       val sorter = new UnsafeKVExternalSorter(
@@ -264,7 +265,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     val schema = new StructType().add("i", IntegerType)
 
     try {
-      val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null)
+      val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties(), null)
       TaskContext.setTaskContext(context)
       val expectedSpillSize = map1.getTotalMemoryConsumption + map2.getTotalMemoryConsumption
       val sorter1 = new UnsafeKVExternalSorter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1640a9611ec35..3b9984a312e57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -112,7 +112,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
       (i, converter(Row(i)))
     }
     val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties, null)
 
     val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
       taskContext,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
index 3e47fd4289bef..a672a3fb1b344 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -36,7 +36,7 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte
     sc = new SparkContext("local[2, 4]", "test", conf)
     val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
     TaskContext.setTaskContext(
-      new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null))
+      new TaskContextImpl(0, 0, 0, 0, 0, 1, taskManager, new Properties, null))
   }
 
   override def afterAll(): Unit = try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
index 045901bc20ca4..34c4939cbc1a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
@@ -52,7 +52,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     super.beforeAll()
     val taskManager = new TaskMemoryManager(new TestMemoryManager(sqlContext.sparkContext.conf), 0)
     TaskContext.setTaskContext(
-      new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null))
+      new TaskContextImpl(0, 0, 0, 0, 0, 1, taskManager, new Properties, null))
   }
 
   override def afterAll(): Unit = try {
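
Usage sketch (illustrative, not part of the diff): the BarrierTaskContext change near the top of this patch makes the barrier sync request carry numPartitions() instead of the lazily computed getTaskInfos().size; the two are equal because a barrier stage launches exactly one task per partition. A rough sketch of the user-visible behavior, assuming an active SparkContext named sc and enough executor slots for the barrier stage; the 4-partition RDD is only illustrative:

    import org.apache.spark.BarrierTaskContext

    sc.parallelize(1 to 100, 4).barrier().mapPartitions { iter =>
      val ctx = BarrierTaskContext.get()
      // barrier() completes once all ctx.numPartitions() (here 4) tasks have reached it.
      ctx.barrier()
      Iterator.single(ctx.numPartitions())
    }.collect()  // Array(4, 4, 4, 4)
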