From cf102cbce8049a3eb4963094a21898d166248392 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Thu, 6 Aug 2020 09:42:19 -0700 Subject: [PATCH 01/24] LIHADOOP-48527 Driver-side changes supporting push-based shuffle: 1. Handling of MergeResults from the executors in MapOutputTracker 2. Shuffle merge finalization in DAGScheduler This also includes the following changes: - LIHADOOP-52972 Tests for changes in MapOutputTracker and DAGScheduler related to push-based shuffle. Author: Chandni Singh - LIHADOOP-52202 Utility to create a directory with 770 permission. Author: Chandni Singh - LIHADOOP-52972 Moved isPushBasedShuffleEnabled to Utils and added a unit test for it. Author: Ye Zhou --- .../org/apache/spark/MapOutputTracker.scala | 448 +++++++++++++++--- .../apache/spark/scheduler/MapStatus.scala | 10 +- .../apache/spark/scheduler/MergeStatus.scala | 92 ++++ .../apache/spark/MapOutputTrackerSuite.scala | 62 ++- .../spark/MapStatusesSerDeserBenchmark.scala | 6 +- 5 files changed, 546 insertions(+), 72 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index cdec1982b448..e40051cdbb4b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -35,7 +35,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus, MergeStatus, OutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -49,7 +49,7 @@ import org.apache.spark.util._ * * All public methods of this class are thread-safe. */ -private class ShuffleStatus(numPartitions: Int) extends Logging { +private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Logging { private val (readLock, writeLock) = { val lock = new ReentrantReadWriteLock() @@ -86,6 +86,8 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { // Exposed for testing val mapStatuses = new Array[MapStatus](numPartitions) + val mergeStatuses = new Array[MergeStatus](numReducers) + /** * The cached result of serializing the map statuses array. This cache is lazily populated when * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. */ @@ -102,6 +104,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + private[this] var cachedSerializedMergeStatus: Array[Byte] = _ + + private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _ + /** * Counter tracking the number of partitions that have output. This is a performance optimization * to avoid having to count the number of non-null entries in the `mapStatuses` array and should */ private[this] var _numAvailableOutputs: Int = 0 + /** + * Counter tracking the number of MergeStatus results received so far. + */ + private[this] var _numAvailableMergeResults: Int = 0 + /** * Register a map output. 
If there is already a registered location for the map output then it * will be replaced by the new location. @@ -155,6 +166,28 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { } } + /** + * Register a merge result. + */ + def addMergeResult(reduceId: Int, status: MergeStatus): Unit = synchronized { + if (mergeStatuses(reduceId) == null) { + _numAvailableMergeResults += 1 + invalidateSerializedMergeOutputStatusCache() + } + mergeStatuses(reduceId) = status + } + + /** + * Remove the merge result which was served by the specified block manager. + */ + def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + /** * Removes all shuffle outputs associated with this host. Note that this will also remove * outputs which are served by an external shuffle server (if one exists). @@ -186,6 +219,13 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { invalidateSerializedMapOutputStatusCache() } } + for (reduceId <- mergeStatuses.indices) { + if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } } /** @@ -195,6 +235,13 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { _numAvailableOutputs } + /** + * Number of merged partitions that have already been finalized. + */ + def numAvailableMergeResults: Int = synchronized { + _numAvailableMergeResults + } + /** * Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ @@ -229,7 +276,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { if (result == null) withWriteLock { if (cachedSerializedMapStatus == null) { - val serResult = MapOutputTracker.serializeMapStatuses( + val serResult = MapOutputTracker.serializeOutputStatuses( mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMapStatus = serResult._1 cachedSerializedBroadcast = serResult._2 @@ -241,6 +288,23 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { result } + /** + * Serializes the mergeStatus array into an efficient compressed format. + */ + def serializedMergeStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): Array[Byte] = synchronized { + if (cachedSerializedMergeStatus eq null) { + val serResult = MapOutputTracker.serializeOutputStatuses( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + cachedSerializedMergeStatus + } + // Used in testing. def hasCachedSerializedBroadcast: Boolean = withReadLock { cachedSerializedBroadcast != null @@ -254,6 +318,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { f(mapStatuses) } + def withMergeStatuses[T](f: Array[MergeStatus] => T): T = synchronized { + f(mergeStatuses) + } + /** * Clears the cached serialized map output statuses. */ @@ -269,14 +337,36 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { } cachedSerializedMapStatus = null } + + /** + * Clears the cached serialized merge result statuses. 
+ */ + def invalidateSerializedMergeOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcastMergeStatus != null) { + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. + cachedSerializedBroadcastMergeStatus.destroy(blocking = false) + } + cachedSerializedBroadcastMergeStatus = null + } + cachedSerializedMergeStatus = null + } } private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage +private[spark] case class GetMergeResultStatuses(shuffleId: Int) + extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) +/** + * The boolean flag in the case class indicates whether the request is for map output or not. + * If false, the request is for merge statuses instead. + */ +private[spark] case class GetOutputStatusesMessage(shuffleId: Int, + fetchMapOutput: Boolean, context: RpcCallContext) /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( @@ -289,7 +379,12 @@ private[spark] class MapOutputTrackerMasterEndpoint( case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}") - tracker.post(new GetMapOutputMessage(shuffleId, context)) + tracker.post(GetOutputStatusesMessage(shuffleId, true, context)) + + case GetMergeResultStatuses(shuffleId: Int) => + val hostPort = context.senderAddress.hostPort + logInfo(s"Asked to send merge result locations for shuffle $shuffleId to $hostPort") + tracker.post(GetOutputStatusesMessage(shuffleId, false, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -367,6 +462,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors upon fetch failure on a merged shuffle block. This is to get the server + * URIs and output sizes for each shuffle block that is merged in the specified merged shuffle + * block so fetch failure on a merged shuffle block can fall back to fetching the unmerged blocks. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -408,6 +512,11 @@ private[spark] class MapOutputTrackerMaster( // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Fraction of map outputs that must be merged at one location for it to be considered as + // a preferred location for a reduce task. To avoid computation overhead, the fraction is + // w.r.t the number of blocks instead of block sizes. + private val MERGE_REF_LOCS_FRACTION = 0.6 + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. 
// Exposed for testing @@ -415,8 +524,10 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // requests for map output statuses - private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + // requests for map/merge output statuses + private val outputStatusesRequests = new LinkedBlockingQueue[GetOutputStatusesMessage] + + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -439,8 +550,8 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetMapOutputMessage): Unit = { - mapOutputRequests.offer(message) + def post(message: GetOutputStatusesMessage): Unit = { + outputStatusesRequests.offer(message) } /** Message loop used for dispatching messages. */ @@ -449,21 +560,29 @@ private[spark] class MapOutputTrackerMaster( try { while (true) { try { - val data = mapOutputRequests.take() - if (data == PoisonPill) { + val data = outputStatusesRequests.take() + if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it. - mapOutputRequests.offer(PoisonPill) + outputStatusesRequests.offer(PoisonPill) return } val context = data.context val shuffleId = data.shuffleId val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) val shuffleStatus = shuffleStatuses.get(shuffleId).head - context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, - conf)) + if (data.fetchMapOutput) { + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, + conf)) + } else { + logDebug("Handling request to send merge output locations for shuffle " + shuffleId + + " to " + hostPort) + context.reply( + shuffleStatus.serializedMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, + conf)) + } } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -475,15 +594,15 @@ private[spark] class MapOutputTrackerMaster( } /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new GetMapOutputMessage(-99, null) + private val PoisonPill = new GetOutputStatusesMessage(-99, true, null) // Used only in unit tests. 
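The message loop above drains a single queue for both request kinds and stops when it sees the PoisonPill sentinel, re-offering it so sibling loops also exit. A minimal, self-contained sketch of that dispatch pattern follows; StatusRequest, Poison, and the println replies are hypothetical stand-ins rather than Spark's actual types:

    import java.util.concurrent.{Executors, LinkedBlockingQueue}

    // Hypothetical stand-in for GetOutputStatusesMessage: a single queue carries
    // both request kinds, distinguished by the fetchMapOutput flag.
    case class StatusRequest(shuffleId: Int, fetchMapOutput: Boolean)

    object DispatchLoopSketch {
      private val requests = new LinkedBlockingQueue[StatusRequest]
      // Sentinel signalling shutdown, mirroring PoisonPill (shuffle ID -99 is unused).
      private val Poison = StatusRequest(-99, fetchMapOutput = true)

      private def loop: Runnable = () => {
        var done = false
        while (!done) {
          val req = requests.take()
          if (req == Poison) {
            requests.offer(Poison) // put it back so the other loops can see it too
            done = true
          } else if (req.fetchMapOutput) {
            println(s"would reply with serialized map statuses for shuffle ${req.shuffleId}")
          } else {
            println(s"would reply with serialized merge statuses for shuffle ${req.shuffleId}")
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val pool = Executors.newFixedThreadPool(2)
        (1 to 2).foreach(_ => pool.execute(loop))
        requests.offer(StatusRequest(10, fetchMapOutput = true))
        requests.offer(StatusRequest(10, fetchMapOutput = false))
        requests.offer(Poison)
        pool.shutdown()
      }
    }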
private[spark] def getNumCachedSerializedBroadcast: Int = { shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) } - def registerShuffle(shuffleId: Int, numMaps: Int): Unit = { - if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { + def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int = 0): Unit = { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -524,10 +643,35 @@ private[spark] class MapOutputTrackerMaster( } } + def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = { + shuffleStatuses(shuffleId).addMergeResult(reduceId, status) + } + + def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = { + statuses.foreach { + case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status) + } + } + + def unregisterMergeResult(shuffleId: Int, reduceId: Int, bmAddress: BlockManagerId): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMergeResult(reduceId, bmAddress) + incrementEpoch() + // TODO how to deal with epoch? Do we need separate epochs for map output and merge result + // TODO right now the reducer uses the epoch to decide whether their local cached + // TODO map output status is outdated or not. Ideally, updating merge result shouldn't + // TODO impact map output status. + case None => + throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int): Unit = { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } @@ -557,6 +701,10 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) } + def getNumAvailableMergeResults(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0) + } + /** * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None * if the MapOutputTrackerMaster doesn't know about this shuffle. */ shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } + // TODO do we need to provide a method to calculate missing maps for a MergeStatus of a + // TODO given shuffle? + /** * Grouped function of Range, this is to avoid traverse of all elements of Range using * IterableLike's grouped function. @@ -633,7 +784,9 @@ private[spark] class MapOutputTrackerMaster( /** * Return the preferred hosts on which to run the given map output partition in a given shuffle, - * i.e. the nodes that the most outputs for that partition are on. + * i.e. the nodes that the most outputs for that partition are on. If the map output is + * pre-merged, then return the node where the merged block is located if the merge ratio is + * above the threshold. 
* * @param dep shuffle dependency object * @param partitionId map output partition that we want to read @@ -641,15 +794,40 @@ private[spark] class MapOutputTrackerMaster( */ def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int) : Seq[String] = { - if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && - dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { - val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, - dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) - if (blockManagerIds.nonEmpty) { - blockManagerIds.get.map(_.host) + val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull + if (shuffleStatus != null) { + // Check if the map output is pre-merged and if the merge ratio is above the threshold. + // If so, the location of the merged block is the preferred location. + val preferredLoc = if (pushBasedShuffleEnabled) { + shuffleStatus.withMergeStatuses { statuses => + val status = statuses(partitionId) + val numMaps = dep.rdd.partitions.length + if (status != null && status.getMissingMaps(numMaps).length.toDouble / numMaps + <= (1 - REDUCER_PREF_LOCS_FRACTION)) { + Seq(status.location.host) + } else { + Nil + } + } } else { Nil } + if (!preferredLoc.isEmpty) { + preferredLoc + } else { + if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && + dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { + val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, + dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) + if (blockManagerIds.nonEmpty) { + blockManagerIds.get.map(_.host) + } else { + Nil + } + } else { + Nil + } + } } else { Nil } @@ -751,6 +929,7 @@ private[spark] class MapOutputTrackerMaster( return epoch } } + // TODO should we have a separate epoch for merge status? // This method is only called in local-mode. def getMapSizesByExecutorId( @@ -774,8 +953,16 @@ private[spark] class MapOutputTrackerMaster( } } + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + Seq.empty + } + override def stop(): Unit = { - mapOutputRequests.offer(PoisonPill) + outputStatusesRequests.offer(PoisonPill) threadpool.shutdown() try { sendTracker(StopMapOutputTracker) @@ -799,6 +986,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + val mergeStatuses: Map[Int, Array[MergeStatus]] = + new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala + + private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf) + /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching * the same shuffle block. 
@@ -812,17 +1004,44 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId") - val statuses = getStatuses(shuffleId, conf) + val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf) try { - val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex + val actualEndMapIndex = + if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex logDebug(s"Convert map statuses for shuffle $shuffleId, " + s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition") MapOutputTracker.convertMapStatuses( - shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex) + shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex, + actualEndMapIndex, Option(mergedOutputStatuses)) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + /** + * Called from executors upon fetch failure on a merged shuffle block. This is to get the server + * URIs and output sizes for each shuffle block that is merged in the specified merged shuffle + * block so fetch failure on a merged shuffle block can fall back to fetching the unmerged blocks. + */ + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, mergeResultStatues) = getStatuses(shuffleId, conf) + try { + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, + mapOutputStatuses, mergeResultStatues) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() throw e } } @@ -833,26 +1052,36 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr * * (It would be nice to remove this restriction in the future.) 
*/ - private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + private def getStatuses( + shuffleId: Int, conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = { + val mapOutputStatuses = mapStatuses.get(shuffleId).orNull + val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull + if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses == null)) { + logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them") val startTimeNs = System.nanoTime() fetchingLock.withLock(shuffleId) { - var fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedMapStatuses == null) { + logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint) val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf) + logInfo("Got the map output locations") + mapStatuses.put(shuffleId, fetchedMapStatuses) } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchMergeResult && fetchedMergeStatuses == null) { + logInfo("Doing the merge fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = askTracker[Array[Byte]](GetMergeResultStatuses(shuffleId)) + fetchedMergeStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf) + logInfo("Got the merge output locations") + mergeStatuses.put(shuffleId, fetchedMergeStatuses) + } + logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - fetchedStatuses + (fetchedMapStatuses, fetchedMergeStatuses) } } else { - statuses + (mapOutputStatuses, mergeResultStatuses) } } @@ -860,6 +1089,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) + mergeStatuses.remove(shuffleId) } /** @@ -873,6 +1103,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Updating epoch to " + newEpoch + " and clearing cache") epoch = newEpoch mapStatuses.clear() + mergeStatuses.clear() } } } @@ -884,11 +1115,11 @@ private[spark] object MapOutputTracker extends Logging { private val DIRECT = 0 private val BROADCAST = 1 - // Serialize an array of map output locations into an efficient byte format so that we can send - // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will - // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses( - statuses: Array[MapStatus], + // Serialize an array of map/merge output locations into an efficient byte format so that we can + // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will + // generally be pretty compressible because many outputs will be on the same hostname. 
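serializeOutputStatuses writes a one-byte tag (DIRECT or BROADCAST) ahead of the payload so the deserializer knows whether the bytes are the statuses themselves or a serialized Broadcast handle. A stripped-down sketch of that framing, assuming plain Java serialization and omitting compression and the broadcast path:

    import java.io._

    // Minimal sketch of the DIRECT/BROADCAST framing: the first byte tags how
    // the remaining payload should be interpreted. The names here are
    // illustrative, not Spark's.
    object TagFramingSketch {
      val DIRECT: Byte = 0
      val BROADCAST: Byte = 1

      def serialize(statuses: Serializable): Array[Byte] = {
        val out = new ByteArrayOutputStream
        out.write(DIRECT) // tag first, then the serialized payload
        val oos = new ObjectOutputStream(out)
        oos.writeObject(statuses)
        oos.close()
        out.toByteArray
      }

      def deserialize(bytes: Array[Byte]): AnyRef = bytes(0) match {
        case DIRECT =>
          // skip the tag byte and deserialize the rest
          val in = new ObjectInputStream(
            new ByteArrayInputStream(bytes, 1, bytes.length - 1))
          try in.readObject() finally in.close()
        case BROADCAST =>
          sys.error("sketch omits the broadcast path")
        case tag =>
          throw new IllegalArgumentException(s"Unexpected byte tag = $tag")
      }
    }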
+ def serializeOutputStatuses[T <: OutputStatus]( + statuses: Array[T], broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, @@ -924,7 +1155,7 @@ private[spark] object MapOutputTracker extends Logging { oos.close() } val outArr = out.toByteArray - logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arr.length) (outArr, bcast) } else { (arr, null) } } // Opposite of serializeMapStatuses. - def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = { + def deserializeOutputStatuses[T <: OutputStatus]( + bytes: Array[Byte], conf: SparkConf): Array[T] = { assert (bytes.length > 0) def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { @@ -951,15 +1183,15 @@ private[spark] object MapOutputTracker extends Logging { bytes(0) match { case DIRECT => - deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[T]] case BROADCAST => // deserialize the Broadcast, pull .value array out of it, and then deserialize that val bcast = deserializeObject(bytes, 1, bytes.length - 1). asInstanceOf[Broadcast[Array[Byte]]] - logInfo("Broadcast mapstatuses size = " + bytes.length + + logInfo("Broadcast outputstatuses size = " + bytes.length + ", actual size = " + bcast.value.length) // Important - ignore the DIRECT tag ! Start from offset 1 - deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[T]] case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } } @@ -976,9 +1208,10 @@ * @param shuffleId Identifier for the shuffle * @param startPartition Start of map output partition ID range (included in range) * @param endPartition End of map output partition ID range (excluded from range) - * @param statuses List of map statuses, indexed by map partition index. + * @param mapStatuses List of map statuses, indexed by map ID. * @param startMapIndex Start Map index. * @param endMapIndex End Map index. + * @param mergeStatuses List of merge statuses, indexed by reduce ID. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block id, shuffle block size, map index) * tuples describing the shuffle blocks that are stored at that block manager. 
@@ -987,18 +1220,46 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus], + mapStatuses: Array[MapStatus], startMapIndex : Int, - endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - assert (statuses != null) + endMapIndex: Int, + mergeStatuses: Option[Array[MergeStatus]] = None): + Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] - val iter = statuses.iterator.zipWithIndex - for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { - if (status == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) - } else { + if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) { + // We have MergeStatus and full range of mapIds are requested so return a merged block. + val numMaps = mapStatuses.length + mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { + case (mergeStatus, partId) => + val remainingMapStatuses = if (mergeStatus != null) { + // If MergeStatus is available for the given partition, add location of the + // pre-merged shuffle partition for this partition ID + // TODO check recent upstream patch to see if there's a better way to handle + // TODO the dummy mapper ID here. For now, it is probably fine to use -1. + splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, -1, partId), mergeStatus.totalSize, -1)) + // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper + // shuffle partition blocks, fall back to fetching the original map produced + // shuffle partition blocks + mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex) + } else { + // If MergeStatus is not available for the given partition, fall back to + // fetching all the original mapper shuffle partition blocks + mapStatuses.zipWithIndex.toSeq + } + // Add location for the mapper shuffle partition blocks + for ((mapStatus, mapIndex) <- remainingMapStatuses) { + validateStatus(mapStatus, shuffleId, partId) + splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), + mapStatus.getSizeForBlock(partId), mapIndex)) + } + } + } else { + val iter = mapStatuses.iterator.zipWithIndex + for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { + validateStatus(status, shuffleId, startPartition) for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { @@ -1008,7 +1269,62 @@ private[spark] object MapOutputTracker extends Logging { } } } - splitsByAddress.mapValues(_.toSeq).iterator } + + private def validateStatus(status: OutputStatus, shuffleId: Int, partition: Int) : Unit = { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) + } + } + + /** + * Given a partition ID, an array of map statuses, and an array of merge statuses, identify + * the metadata about the shuffle partition blocks that are merged into the shuffle block + * represented by the merge status in the array specified by the given partition ID. 
+ * + * @param shuffleId Identifier for the shuffle + * @param partitionId The partition ID of the MergeStatus for which we look for the metadata + * of the merged shuffle partition blocks + * @param mapStatuses List of map statuses, indexed by map ID + * @param mergeStatuses List of merge statuses, index by reduce ID + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapStatusesForMergeStatus( + shuffleId: Int, + partitionId: Int, + mapStatuses: Array[MapStatus], + mergeStatuses: Array[MergeStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + assert (mapStatuses != null && mergeStatuses != null) + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + val mergeStatus = mergeStatuses(partitionId) + // The original MergeStatus is no longer available, we cannot identify the list of + // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException. + if (mergeStatus == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId " + + s"for merge block backup" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage) + } + for ((status, mapId) <- mapStatuses.zipWithIndex) { + // Only add blocks that are merged + if (mergeStatus.tracker.contains(mapId)) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId " + + s"for merge block backup" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, partitionId), status.getSizeForBlock(partitionId))) + } + } + } + splitsByAddress.toSeq + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1239c32cee3a..df40bec6dede 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -28,13 +28,19 @@ import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +/** + * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing + * code to handle MergeStatus. + */ +private[spark] trait OutputStatus + /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus { - /** Location where this task output is. */ +private[spark] sealed trait MapStatus extends OutputStatus { + /** Location where this task was run. */ def location: BlockManagerId def updateLocation(newLoc: BlockManagerId): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala new file mode 100644 index 000000000000..f0b988ee0335 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.network.shuffle.protocol.MergeStatuses +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +/** + * The status for the result of shuffle partition merge for each individual reducer partition + * maintained by the scheduler. The scheduler would separate the MergeStatuses received from + * ExternalShuffleService into individual MergeStatus which is maintained inside MapOutputTracker + * to be served to the reducers when they start fetching shuffle partition blocks. + */ +private[spark] class MergeStatus( + private[this] var loc: BlockManagerId, + private[this] var mapTracker: RoaringBitmap, + private[this] var size: Long) + extends Externalizable with OutputStatus { + + protected def this() = this(null, null, -1) // For deserialization only + + def location: BlockManagerId = loc + + def totalSize: Long = size + + def tracker: RoaringBitmap = mapTracker + + /** + * Get the list of mapper IDs for missing mapper partition blocks that are not merged. + * The reducer will use this information to decide which shuffle partition blocks to + * fetch in the original way. + */ + def getMissingMaps(numMaps: Int): Seq[Int] = { + (0 until numMaps).filter(i => !mapTracker.contains(i)) + } + + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + mapTracker.writeExternal(out) + out.writeLong(size) + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + mapTracker = new RoaringBitmap() + mapTracker.readExternal(in) + size = in.readLong() + } +} + +private[spark] object MergeStatus { + + /** + * Separate a MergeStatuses received from an ExternalShuffleService into individual + * MergeStatus. The scheduler is responsible for providing the location information + * for the given ExternalShuffleService. 
+ */ + def convertMergeStatusesToMergeStatusArr( + mergeStatuses: MergeStatuses, + loc: BlockManagerId): Seq[(Int, MergeStatus)] = { + val mergerLoc = BlockManagerId("", loc.host, loc.port) + mergeStatuses.bitmaps.zipWithIndex.map { + case (bitmap, index) => + val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index)) + (mergeStatuses.reduceIds(index), mergeStatus) + } + } + + def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = { + new MergeStatus(loc, bitmap, size) + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 20b040f7c810..7dc707dc6da0 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -21,13 +21,14 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ +import org.roaringbitmap.RoaringBitmap import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} -import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} +import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} @@ -332,4 +333,63 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } + test("master register and unregister merge result") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 4, 2) + assert(tracker.containsShuffle(10)) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap, 1000L)) + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + bitmap, 1000L)) + assert(tracker.getNumAvailableMergeResults(10) == 2) + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)); + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.stop() + rpcEnv.shutdown() + } + + test("remote fetch with merged shuffle result") { + conf.set("spark.shuffle.push.based.enabled", "true") + conf.set("spark.shuffle.service.enabled", "true") + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 1, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. 
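convertMergeStatusesToMergeStatusArr above flattens one MergeStatuses message from the shuffle service, whose parallel arrays (reduce IDs, bitmaps, sizes) describe every finalized partition, into per-reducer entries. A simplified sketch of that zipping, using hypothetical case classes in place of the protocol types and a Set[Int] in place of the RoaringBitmap:

    // Hypothetical stand-in for the scheduler-side MergeStatus.
    case class MergedPartition(reduceId: Int, mergedMappers: Set[Int], totalSize: Long)

    object MergeStatusConversionSketch {
      // Zip the parallel arrays reported by the shuffle service into one
      // (reduceId, status) pair per finalized reducer partition.
      def flatten(
          reduceIds: Array[Int],
          bitmaps: Array[Set[Int]],
          sizes: Array[Long]): Seq[(Int, MergedPartition)] = {
        reduceIds.indices.map { i =>
          reduceIds(i) -> MergedPartition(reduceIds(i), bitmaps(i), sizes(i))
        }
      }

      def main(args: Array[String]): Unit = {
        val pairs = flatten(Array(0, 1), Array(Set(0, 1), Set(2, 3)), Array(2000L, 2000L))
        assert(pairs.map(_._1) == Seq(0, 1)) // one entry per reduce ID
      }
    }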
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + val bitmap = new RoaringBitmap() + bitmap.add(0) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("a", "hostA", 1000), Array(1000L), 0)) + masterTracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap, 1000L)) + slaveTracker.updateEpoch(masterTracker.getEpoch) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, -1, 0), 1000L, -1))))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesForMergeResult(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 0), size1000)))) + ) + masterTracker.stop() + slaveTracker.stop() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } } diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 78f1246295bf..1af1e0f778cf 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -68,7 +68,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { var serializedMapStatusSizes = 0 var serializedBroadcastSizes = 0 - val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses( + val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeOutputStatuses( shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize, sc.getConf) serializedMapStatusSizes = serializedMapStatus.length @@ -77,12 +77,12 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { } benchmark.addCase("Serialization") { _ => - MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager, + MapOutputTracker.serializeOutputStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize, sc.getConf) } benchmark.addCase("Deserialization") { _ => - val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf) + val result = MapOutputTracker.deserializeOutputStatuses(serializedMapStatus, sc.getConf) assert(result.length == numMaps) } From 81e6a5b64a7426372a2b866c2962eab9b3a03ff1 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 17 Nov 2020 21:33:25 -0800 Subject: [PATCH 02/24] LIHADOOP-53321 Magnet: Merge client shuffle block fetcher related changes --- .../org/apache/spark/MapOutputTracker.scala | 28 ++++++++----------- .../apache/spark/MapOutputTrackerSuite.scala | 2 +- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e40051cdbb4b..fac28aecef35 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -469,7 +469,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ def getMapSizesForMergeResult( shuffleId: Int, - partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** * Deletes map output status information for the specified shuffle stage. @@ -957,7 +957,7 @@ private[spark] class MapOutputTrackerMaster( // enabled in local-mode, this method returns empty list. 
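When a reducer asks for a pre-merged partition, convertMapStatuses returns one merged block plus the original blocks of every mapper missing from the merge (the "holes"). A simplified sketch of that read-plan computation, with plain case classes standing in for Spark's block IDs and statuses:

    // Illustrative types; not Spark's ShuffleBlockId or MapStatus.
    object MergedReadPlanSketch {
      sealed trait Fetch
      case class MergedBlock(partition: Int, size: Long) extends Fetch
      case class OriginalBlock(mapIndex: Int, partition: Int, size: Long) extends Fetch

      def plan(
          partition: Int,
          mergedMappers: Set[Int],
          mergedSize: Long,
          perMapSizes: Array[Long]): Seq[Fetch] = {
        val merged = MergedBlock(partition, mergedSize)
        // Fall back to original blocks only for mappers missing from the merge.
        val holes = perMapSizes.indices.filterNot(mergedMappers.contains).map { i =>
          OriginalBlock(i, partition, perMapSizes(i))
        }
        merged +: holes
      }

      def main(args: Array[String]): Unit = {
        // Mappers 0, 1 and 3 were merged; mapper 2 is a hole.
        val fetches = plan(0, Set(0, 1, 3), 3000L, Array(1000L, 1000L, 1000L, 1000L))
        assert(fetches == Seq(MergedBlock(0, 3000L), OriginalBlock(2, 0, 1000L)))
      }
    }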
def getMapSizesForMergeResult( shuffleId: Int, - partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { Seq.empty } @@ -1029,7 +1029,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr */ override def getMapSizesForMergeResult( shuffleId: Int, - partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") // Fetch the map statuses and merge statuses again since they might have already been // cleared by another task running in the same executor. @@ -1298,9 +1298,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, partitionId: Int, mapStatuses: Array[MapStatus], - mergeStatuses: Array[MergeStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + mergeStatuses: Array[MergeStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { assert (mapStatuses != null && mergeStatuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] val mergeStatus = mergeStatuses(partitionId) // The original MergeStatus is no longer available, we cannot identify the list of // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException. @@ -1310,18 +1310,12 @@ private[spark] object MapOutputTracker extends Logging { logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage) } - for ((status, mapId) <- mapStatuses.zipWithIndex) { - // Only add blocks that are merged - if (mergeStatus.tracker.contains(mapId)) { - if (status == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId " + - s"for merge block backup" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage) - } else { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, partitionId), status.getSizeForBlock(partitionId))) - } + for ((status, mapIndex) <- mapStatuses.zipWithIndex) { + if (mergeStatus.tracker.contains(mapIndex)) { + validateStatus(status, shuffleId, partitionId) + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapId, partitionId), + status.getSizeForBlock(partitionId), mapIndex)) } } splitsByAddress.toSeq diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7dc707dc6da0..7d87a7a151da 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -385,7 +385,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, -1, 0), 1000L, -1))))) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) assert(slaveTracker.getMapSizesForMergeResult(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 0), size1000)))) + Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 0), size1000, 0)))) ) masterTracker.stop() slaveTracker.stop() From e1cc409f5caa6b0571d6cab03300e9bda88da460 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Fri, 12 Jun 2020 22:27:58 -0700 Subject: [PATCH 03/24] 
LIHADOOP-54115 Unregister map and merge outputs on the host when DAG scheduler encounters a shuffle chunk failure RB=2151376 BUG=LIHADOOP-54115 G=spark-reviewers R=yezhou,mshen A=mshen --- .../apache/spark/MapOutputTrackerSuite.scala | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7d87a7a151da..9cb5b4897707 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -355,7 +355,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("remote fetch with merged shuffle result") { + test("get map statuses from merged shuffle") { conf.set("spark.shuffle.push.based.enabled", "true") conf.set("spark.shuffle.service.enabled", "true") val hostname = "localhost" @@ -370,23 +370,34 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.trackerEndpoint = slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - masterTracker.registerShuffle(10, 1, 1) + masterTracker.registerShuffle(10, 4, 1) slaveTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val bitmap = new RoaringBitmap() bitmap.add(0) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), 0)) - masterTracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), - bitmap, 1000L)) + bitmap.add(1) + bitmap.add(2) + bitmap.add(3) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput( + 10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput( + 10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput( + 10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput( + 10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + bitmap, 4000L)) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, -1, 0), 1000L, -1))))) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) assert(slaveTracker.getMapSizesForMergeResult(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 0), size1000, 0)))) - ) + Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0), + (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2), + (ShuffleBlockId(10, 3, 0), size1000, 3))))) masterTracker.stop() slaveTracker.stop() rpcEnv.shutdown() From bfd3a739540e4c508b8454038b5b283072810a7a Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 17 Nov 2020 22:28:40 -0800 Subject: [PATCH 04/24] LIHADOOP-52494 Magnet fallback to origin shuffle blocks when fetch of a shuffle chunk fails --- .../org/apache/spark/MapOutputTracker.scala | 126 ++++++++++++------ .../apache/spark/MapOutputTrackerSuite.scala | 40 +++++- 2 files changed, 122 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fac28aecef35..16a05203ca2d 100644 --- 
a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -29,13 +29,14 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} +import org.roaringbitmap.RoaringBitmap import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus, MergeStatus, OutputStatus} +import org.apache.spark.scheduler.{MapStatus, MergeStatus, OutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -136,19 +137,14 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin * Update the map output location (e.g. during migration). */ def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock { - try { - val mapStatusOpt = mapStatuses.find(_.mapId == mapId) - mapStatusOpt match { - case Some(mapStatus) => - logInfo(s"Updating map output for ${mapId} to ${bmAddress}") - mapStatus.updateLocation(bmAddress) - invalidateSerializedMapOutputStatusCache() - case None => - logWarning(s"Asked to update map output ${mapId} for untracked map status.") - } - } catch { - case e: java.lang.NullPointerException => - logWarning(s"Unable to update map output for ${mapId}, status removed in-flight") + val mapStatusOpt = mapStatuses.find(_.mapId == mapId) + mapStatusOpt match { + case Some(mapStatus) => + logInfo(s"Updating map output for ${mapId} to ${bmAddress}") + mapStatus.updateLocation(bmAddress) + invalidateSerializedMapOutputStatusCache() + case None => + logError(s"Asked to update map output ${mapId} for untracked map status.") } } @@ -471,6 +467,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging shuffleId: Int, partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors upon fetch failure on a merged shuffle block chunk. This is to get the + * server URIs and output sizes for each shuffle block that is merged in the specified merged + * shuffle block chunk so fetch failure on a merged shuffle block chunk can fall back to fetching + * the unmerged blocks. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkBitmap: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -527,7 +534,7 @@ private[spark] class MapOutputTrackerMaster( // requests for map/merge output statuses private val outputStatusesRequests = new LinkedBlockingQueue[GetOutputStatusesMessage] - private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) + private val pushBasedShuffleEnabled = Utils.isPushShuffleEnabled(conf) // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -961,6 +968,15 @@ private[spark] class MapOutputTrackerMaster( Seq.empty } + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. 
+ def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty + } + override def stop(): Unit = { outputStatusesRequests.offer(PoisonPill) threadpool.shutdown() @@ -989,7 +1005,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mergeStatuses: Map[Int, Array[MergeStatus]] = new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala - private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf) + private val fetchMergeResult = Utils.isPushShuffleEnabled(conf) /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching @@ -1033,10 +1049,40 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") // Fetch the map statuses and merge statuses again since they might have already been // cleared by another task running in the same executor. - val (mapOutputStatuses, mergeResultStatues) = getStatuses(shuffleId, conf) + val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf) try { + val mergeStatus = mergeResultStatuses(partitionId) + // The original MergeStatus is no longer available, we cannot identify the list of + // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException. + MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId) MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, - mapOutputStatuses, mergeResultStatues) + mapOutputStatuses, mergeStatus.tracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + /** + * Called from executors upon fetch failure on a merged shuffle block chunk. This is to get the + * server URIs and output sizes for each shuffle block that is merged in the specified merged + * shuffle block chunk so fetch failure on a merged shuffle block chunk can fall back to fetching + * the unmerged blocks. + */ + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. 
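Each merged block chunk carries a RoaringBitmap of the mapper IDs whose data it contains, so a chunk-level fetch failure only needs to fall back to those mappers' original blocks. A standalone sketch of that selection; the tuple type is an illustrative stand-in for MapStatus, while the RoaringBitmap calls are the real API:

    import org.roaringbitmap.RoaringBitmap

    object ChunkFallbackSketch {
      // (host, sizeForThisPartition) is a hypothetical stand-in for MapStatus.
      def blocksToRefetch(
          chunkTracker: RoaringBitmap,
          mapOutputs: Array[(String, Long)]): Seq[(Int, String, Long)] = {
        mapOutputs.zipWithIndex.collect {
          // Keep only the mappers whose blocks were merged into the failed chunk.
          case ((host, size), mapIndex) if chunkTracker.contains(mapIndex) =>
            (mapIndex, host, size)
        }.toSeq
      }

      def main(args: Array[String]): Unit = {
        val chunk = new RoaringBitmap()
        chunk.add(0)
        chunk.add(2) // the failed chunk held data from mappers 0 and 2
        val outputs = Array(("hostA", 1000L), ("hostB", 1000L), ("hostA", 1000L))
        // Mapper 1 was never part of this chunk, so it is not re-fetched.
        assert(blocksToRefetch(chunk, outputs).map(_._1) == Seq(0, 2))
      }
    }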
+ val (mapOutputStatuses, _) = getStatuses(shuffleId, conf) + try { + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses, + chunkTracker) } catch { // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: case e: MetadataFetchFailedException => @@ -1251,9 +1297,11 @@ private[spark] object MapOutputTracker extends Logging { // Add location for the mapper shuffle partition blocks for ((mapStatus, mapIndex) <- remainingMapStatuses) { validateStatus(mapStatus, shuffleId, partId) - splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), - mapStatus.getSizeForBlock(partId), mapIndex)) + val size = mapStatus.getSizeForBlock(partId) + if (size != 0) { + splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex)) + } } } } else { @@ -1272,14 +1320,6 @@ private[spark] object MapOutputTracker extends Logging { splitsByAddress.mapValues(_.toSeq).iterator } - private def validateStatus(status: OutputStatus, shuffleId: Int, partition: Int) : Unit = { - if (status == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) - } - } - /** * Given a partition ID, an array of map statuses, and an array of merge statuses, identify * the metadata about the shuffle partition blocks that are merged into the shuffle block @@ -1289,7 +1329,8 @@ private[spark] object MapOutputTracker extends Logging { * @param partitionId The partition ID of the MergeStatus for which we look for the metadata * of the merged shuffle partition blocks * @param mapStatuses List of map statuses, indexed by map ID - * @param mergeStatuses List of merge statuses, index by reduce ID + * @param tracker bitmap containing mapIds that belong to the merged block or merged block + * chunk. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. @@ -1298,21 +1339,13 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, partitionId: Int, mapStatuses: Array[MapStatus], - mergeStatuses: Array[MergeStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - assert (mapStatuses != null && mergeStatuses != null) + tracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null && tracker != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] - val mergeStatus = mergeStatuses(partitionId) - // The original MergeStatus is no longer available, we cannot identify the list of - // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException. 
- if (mergeStatus == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId " + - s"for merge block backup" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage) - } for ((status, mapIndex) <- mapStatuses.zipWithIndex) { - if (mergeStatus.tracker.contains(mapIndex)) { - validateStatus(status, shuffleId, partitionId) + // Only add blocks that are merged + if (tracker.contains(mapIndex)) { + MapOutputTracker.validateStatus(status, shuffleId, partitionId) splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += ((ShuffleBlockId(shuffleId, status.mapId, partitionId), status.getSizeForBlock(partitionId), mapIndex)) @@ -1321,4 +1354,11 @@ private[spark] object MapOutputTracker extends Logging { splitsByAddress.toSeq } + def validateStatus(status: OutputStatus, shuffleId: Int, partition: Int) : Unit = { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) + } + } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 9cb5b4897707..f9bd75d5a75d 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -368,7 +368,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 4, 1) slaveTracker.updateEpoch(masterTracker.getEpoch) @@ -403,4 +403,42 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() slaveRpcEnv.shutdown() } + + test("get map statuses for merged shuffle block chunks") { + conf.set("spark.shuffle.push.based.enabled", "true") + conf.set("spark.shuffle.service.enabled", "true") + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + val chunkBitmap = new RoaringBitmap() + chunkBitmap.add(0) + chunkBitmap.add(2) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap) === + Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), 
size1000, 0),
+      (ShuffleBlockId(10, 2, 0), size1000, 2))))
+    )
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
 }

From 2bf9502e9a79ff20d049bc3adcd0e1e5dedfdc73 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Thu, 20 Aug 2020 00:14:03 -0700
Subject: [PATCH 05/24] Magnet: Serialization of merge status should use the
 reentrant read-write lock

---
 .../org/apache/spark/MapOutputTracker.scala   | 23 +++++++++++++++----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 16a05203ca2d..ca90abe17b1a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -291,12 +291,25 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin
       broadcastManager: BroadcastManager,
       isLocal: Boolean,
       minBroadcastSize: Int,
-      conf: SparkConf): Array[Byte] = synchronized {
-    if (cachedSerializedMergeStatus eq null) {
-      val serResult = MapOutputTracker.serializeOutputStatuses(
+      conf: SparkConf): Array[Byte] = {
+    var result: Array[Byte] = null
+
+    withReadLock {
+      if (cachedSerializedMergeStatus != null) {
+        result = cachedSerializedMergeStatus
+      }
+    }
+
+    if (result == null) withWriteLock {
+      if (cachedSerializedMergeStatus == null ) {
+        val serResult = MapOutputTracker.serializeOutputStatuses(
         mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
-      cachedSerializedMergeStatus = serResult._1
-      cachedSerializedBroadcastMergeStatus = serResult._2
+        cachedSerializedMergeStatus = serResult._1
+        cachedSerializedBroadcastMergeStatus = serResult._2
+      }
+      // The following line has to be outside if statement since it's possible that another thread
+      // initializes cachedSerializedMergeStatus in-between `withReadLock` and `withWriteLock`.
+      result = cachedSerializedMergeStatus
     }
     cachedSerializedMergeStatus
   }

From 23975c41efe6aff374f935ac6e0c35a72f754c91 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Fri, 20 Nov 2020 11:45:52 -0800
Subject: [PATCH 06/24] Fixed the compilation error in MOT

---
 core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ca90abe17b1a..6d588f76e654 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -547,7 +547,7 @@ private[spark] class MapOutputTrackerMaster(
   // requests for map/merge output statuses
   private val outputStatusesRequests = new LinkedBlockingQueue[GetOutputStatusesMessage]

-  private val pushBasedShuffleEnabled = Utils.isPushShuffleEnabled(conf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)

   // Thread pool used for handling map output status requests. This is a separate thread pool
   // to ensure we don't block the normal dispatcher threads.
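The double-checked caching idiom that PATCH 05 adopts above is subtle: readers first probe the cache under the read lock, and only a thread that finds the cache empty pays for serialization under the write lock. The result must then be re-read inside the write lock, because a competing thread may have populated the cache between the two lock acquisitions. A minimal standalone sketch of the same idiom, using hypothetical names that are not part of the patch series:

    import java.util.concurrent.locks.ReentrantReadWriteLock

    class CachedBytes(compute: () => Array[Byte]) {
      private val lock = new ReentrantReadWriteLock()
      private var cached: Array[Byte] = _

      def get(): Array[Byte] = {
        var result: Array[Byte] = null
        lock.readLock().lock()
        try {
          // Fast path: the cache is usually already populated.
          if (cached != null) result = cached
        } finally {
          lock.readLock().unlock()
        }
        if (result == null) {
          lock.writeLock().lock()
          try {
            // Re-check under the write lock: only one thread computes.
            if (cached == null) cached = compute()
            // Re-read: another thread may have won the race in between locks.
            result = cached
          } finally {
            lock.writeLock().unlock()
          }
        }
        result
      }
    }

Note that the intermediate version of serializedMergeStatus above still returns the field rather than the local `result`; PATCH 07 below reworks this method and removes that wart.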
@@ -1018,7 +1018,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mergeStatuses: Map[Int, Array[MergeStatus]] = new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala - private val fetchMergeResult = Utils.isPushShuffleEnabled(conf) + private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf) /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching From f51a80616d78d7cf0fe108b5dc8940608e499133 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Sun, 22 Nov 2020 22:51:57 -0800 Subject: [PATCH 07/24] Prepare for PR --- .../org/apache/spark/MapOutputTracker.scala | 266 +++++++++--------- .../apache/spark/scheduler/MapStatus.scala | 2 +- .../apache/spark/scheduler/MergeStatus.scala | 21 +- .../apache/spark/MapOutputTrackerSuite.scala | 134 +++++++-- 4 files changed, 269 insertions(+), 154 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6d588f76e654..d22ede4e28bf 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -87,24 +87,35 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin // Exposed for testing val mapStatuses = new Array[MapStatus](numPartitions) + /** + * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the + * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for + * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array + * provides a reducer oriented view of the shuffle status specifically for the results of + * merging shuffle partition blocks into per-partition merged shuffle files. + */ val mergeStatuses = new Array[MergeStatus](numReducers) /** * The cached result of serializing the map statuses array. This cache is lazily populated when - * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + * [[serializedOutputStatus]] is called. The cache is invalidated when map outputs are removed. */ private[this] var cachedSerializedMapStatus: Array[Byte] = _ /** - * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] - * serializes the map statuses array it may detect that the result is too large to send in a - * single RPC, in which case it places the serialized array into a broadcast variable and then - * sends a serialized broadcast variable instead. This variable holds a reference to that - * broadcast variable in order to keep it from being garbage collected and to allow for it to be - * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + * Broadcast variable holding serialized map output statuses array. When + * [[serializedOutputStatus]] serializes the map statuses array it may detect that the result is + * too large to send in a single RPC, in which case it places the serialized array into a + * broadcast variable and then sends a serialized broadcast variable instead. This variable holds + * a reference to that broadcast variable in order to keep it from being garbage collected and + * to allow for it to be explicitly destroyed later on when the ShuffleMapStage is + * garbage-collected. 
*/ private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + /** + * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus. + */ private[this] var cachedSerializedMergeStatus: Array[Byte] = _ private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _ @@ -114,10 +125,10 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin * to avoid having to count the number of non-null entries in the `mapStatuses` array and should * be equivalent to`mapStatuses.count(_ ne null)`. */ - private[this] var _numAvailableOutputs: Int = 0 + private[this] var _numAvailableMapOutputs: Int = 0 /** - * Counter tracking the number of MergeStatus result received so far. + * Counter tracking the number of MergeStatus results received so far from the shuffle services. */ private[this] var _numAvailableMergeResults: Int = 0 @@ -127,7 +138,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin */ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { if (mapStatuses(mapIndex) == null) { - _numAvailableOutputs += 1 + _numAvailableMapOutputs += 1 invalidateSerializedMapOutputStatusCache() } mapStatuses(mapIndex) = status @@ -137,14 +148,19 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin * Update the map output location (e.g. during migration). */ def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock { - val mapStatusOpt = mapStatuses.find(_.mapId == mapId) - mapStatusOpt match { - case Some(mapStatus) => - logInfo(s"Updating map output for ${mapId} to ${bmAddress}") - mapStatus.updateLocation(bmAddress) - invalidateSerializedMapOutputStatusCache() - case None => - logError(s"Asked to update map output ${mapId} for untracked map status.") + try { + val mapStatusOpt = mapStatuses.find(_.mapId == mapId) + mapStatusOpt match { + case Some(mapStatus) => + logInfo(s"Updating map output for ${mapId} to ${bmAddress}") + mapStatus.updateLocation(bmAddress) + invalidateSerializedMapOutputStatusCache() + case None => + logWarning(s"Asked to update map output ${mapId} for untracked map status.") + } + } catch { + case e: java.lang.NullPointerException => + logWarning(s"Unable to update map output for ${mapId}, status removed in-flight") } } @@ -156,7 +172,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { - _numAvailableOutputs -= 1 + _numAvailableMapOutputs -= 1 mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } @@ -165,7 +181,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin /** * Register a merge result. */ - def addMergeResult(reduceId: Int, status: MergeStatus): Unit = synchronized { + def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock { if (mergeStatuses(reduceId) == null) { _numAvailableMergeResults += 1 invalidateSerializedMergeOutputStatusCache() @@ -173,10 +189,12 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin mergeStatuses(reduceId) = status } + // TODO support updateMergeResult for similar use cases as updateMapOutput + /** * Remove the merge result which was served by the specified block manager. 
   */
-  def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = synchronized {
+  def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
     if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) {
       _numAvailableMergeResults -= 1
       mergeStatuses(reduceId) = null
       invalidateSerializedMergeOutputStatusCache()
@@ -210,7 +228,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
       if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
-        _numAvailableOutputs -= 1
+        _numAvailableMapOutputs -= 1
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
@@ -225,16 +243,17 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin
   }
 
   /**
-   * Number of partitions that have shuffle outputs.
+   * Number of partitions that have shuffle map outputs.
    */
-  def numAvailableOutputs: Int = withReadLock {
-    _numAvailableOutputs
+  def numAvailableMapOutputs: Int = withReadLock {
+    _numAvailableMapOutputs
   }
 
   /**
-   * Number of merged partitions that have already been finalized.
+   * Number of shuffle partitions that have already been merge finalized when push-based
+   * shuffle is enabled.
    */
-  def numAvailableMergeResults: Int = synchronized {
+  def numAvailableMergeResults: Int = withReadLock {
     _numAvailableMergeResults
   }
 
@@ -243,75 +262,65 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin
    */
   def findMissingPartitions(): Seq[Int] = withReadLock {
     val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    assert(missing.size == numPartitions - _numAvailableMapOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
     missing
   }
 
   /**
-   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
-   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   * Serializes the mapStatuses or mergeStatuses array into an efficient compressed format. See
+   * the comments on `MapOutputTracker.serializeOutputStatuses()` for more details on the
+   * serialization format.
    *
    * This method is designed to be called multiple times and implements caching in order to speed
    * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
-   * serialize the map statuses then serialization will only be performed in a single thread and all
-   * other threads will block until the cache is populated.
+   * serialize the statuses array then serialization will only be performed in a single thread and
+   * all other threads will block until the cache is populated.
*/ - def serializedMapStatus( + def serializedOutputStatus( broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, - conf: SparkConf): Array[Byte] = { + conf: SparkConf, + isMapOutput: Boolean): Array[Byte] = { var result: Array[Byte] = null withReadLock { - if (cachedSerializedMapStatus != null) { - result = cachedSerializedMapStatus + if (isMapOutput) { + if (cachedSerializedMapStatus != null) { + result = cachedSerializedMapStatus + } + } else { + if (cachedSerializedMergeStatus != null) { + result = cachedSerializedMergeStatus + } } } if (result == null) withWriteLock { - if (cachedSerializedMapStatus == null) { - val serResult = MapOutputTracker.serializeOutputStatuses( - mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) - cachedSerializedMapStatus = serResult._1 - cachedSerializedBroadcast = serResult._2 - } - // The following line has to be outside if statement since it's possible that another thread - // initializes cachedSerializedMapStatus in-between `withReadLock` and `withWriteLock`. - result = cachedSerializedMapStatus - } - result - } - - /** - * Serializes the mergeStatus array into an efficient compressed format. - */ - def serializedMergeStatus( - broadcastManager: BroadcastManager, - isLocal: Boolean, - minBroadcastSize: Int, - conf: SparkConf): Array[Byte] = { - var result: Array[Byte] = null - - withReadLock { - if (cachedSerializedMergeStatus != null) { + if (isMapOutput) { + if (cachedSerializedMapStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + // The following line has to be outside if statement since it's possible that another + // thread initializes cachedSerializedMapStatus in-between `withReadLock` and + // `withWriteLock`. + result = cachedSerializedMapStatus + } else { + if (cachedSerializedMergeStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + // The following line has to be outside if statement for similar reasons as above. result = cachedSerializedMergeStatus } } - - if (result == null) withWriteLock { - if (cachedSerializedMergeStatus == null ) { - val serResult = MapOutputTracker.serializeOutputStatuses( - mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) - cachedSerializedMergeStatus = serResult._1 - cachedSerializedBroadcastMergeStatus = serResult._2 - } - // The following line has to be outside if statement since it's possible that another thread - // initializes cachedSerializedMergeStatus in-between `withReadLock` and `withWriteLock`. - result = cachedSerializedMergeStatus - } - cachedSerializedMergeStatus + result } // Used in testing. @@ -327,7 +336,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin f(mapStatuses) } - def withMergeStatuses[T](f: Array[MergeStatus] => T): T = synchronized { + def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock { f(mergeStatuses) } @@ -350,12 +359,12 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin /** * Clears the cached serialized merge result statuses. 
*/ - def invalidateSerializedMergeOutputStatusCache(): Unit = synchronized { + def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock { if (cachedSerializedBroadcastMergeStatus != null) { Utils.tryLogNonFatalError { // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup // RPCs to dead executors. - cachedSerializedBroadcastMergeStatus.destroy(blocking = false) + cachedSerializedBroadcastMergeStatus.destroy() } cachedSerializedBroadcastMergeStatus = null } @@ -388,12 +397,12 @@ private[spark] class MapOutputTrackerMasterEndpoint( case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}") - tracker.post(GetOutputStatusesMessage(shuffleId, true, context)) + tracker.post(GetOutputStatusesMessage(shuffleId, fetchMapOutput = true, context)) case GetMergeResultStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send merge result locations for shuffle $shuffleId to $hostPort") - tracker.post(GetOutputStatusesMessage(shuffleId, false, context)) + tracker.post(GetOutputStatusesMessage(shuffleId, fetchMapOutput = false, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -472,19 +481,21 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** - * Called from executors upon fetch failure on a merged shuffle block. This is to get the server - * URIs and output sizes for each shuffle block that is merged in the specified merged shuffle - * block so fetch failure on a merged shuffle block can fall back to fetching the unmerged blocks. + * Called from executors upon fetch failure on an entire merged shuffle partition. Such failures + * can happen if the shuffle client fails to fetch the metadata for the given merged shuffle + * partition. This method is to get the server URIs and output sizes for each shuffle block that + * is merged in the specified merged shuffle block so fetch failure on a merged shuffle block can + * fall back to fetching the unmerged blocks. */ def getMapSizesForMergeResult( shuffleId: Int, partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** - * Called from executors upon fetch failure on a merged shuffle block chunk. This is to get the - * server URIs and output sizes for each shuffle block that is merged in the specified merged - * shuffle block chunk so fetch failure on a merged shuffle block chunk can fall back to fetching - * the unmerged blocks. + * Called from executors upon fetch failure on a merged shuffle partition chunk. This is to get + * the server URIs and output sizes for each shuffle block that is merged in the specified merged + * shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back to + * fetching the unmerged blocks. */ def getMapSizesForMergeResult( shuffleId: Int, @@ -532,11 +543,6 @@ private[spark] class MapOutputTrackerMaster( // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // Fraction of map outputs that must be merged at one location for it to be considered as - // a preferred location for a reduce task. To avoid computation overhead, the fraction is - // w.r.t the number of blocks instead of block sizes. 
-  private val MERGE_REF_LOCS_FRACTION = 0.6
-
   // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
   // Exposed for testing
@@ -594,14 +600,14 @@ private[spark] class MapOutputTrackerMaster(
           logDebug("Handling request to send map output locations for shuffle " + shuffleId +
             " to " + hostPort)
           context.reply(
-            shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast,
-              conf))
+            shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, minSizeForBroadcast,
+              conf, isMapOutput = true))
         } else {
           logDebug("Handling request to send merge output locations for shuffle " + shuffleId +
             " to " + hostPort)
           context.reply(
-            shuffleStatus.serializedMergeStatus(broadcastManager, isLocal, minSizeForBroadcast,
-              conf))
+            shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, minSizeForBroadcast,
+              conf, isMapOutput = false))
         }
       } catch {
         case NonFatal(e) => logError(e.getMessage, e)
@@ -677,11 +683,13 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
         shuffleStatus.removeMergeResult(reduceId, bmAddress)
+        // Here we share the same epoch for both map outputs and merge results. This means
+        // that even if we are only unregistering map output, this would also clear the executor
+        // side cached merge statuses and lead to executors re-fetching the merge statuses which
+        // haven't changed, and vice versa. This is a reasonable compromise to prevent complicating
+        // how the epoch is currently used and to make sure the executor is always working with
+        // a pair of matching map statuses and merge statuses for each shuffle.
         incrementEpoch()
-        // TODO how to deal with epoch? Do we need separate epochs for map output and merge result
-        // TODO right now the reducer uses the epoch to decide whether their local cached
-        // TODO map output status is outdated or not. Ideally, updating merge result shouldn't
-        // TODO impact map output status.
       case None =>
         throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID")
     }
@@ -718,7 +726,7 @@ private[spark] class MapOutputTrackerMaster(
   def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
 
   def getNumAvailableOutputs(shuffleId: Int): Int = {
-    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+    shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0)
   }
 
   def getNumAvailableMergeResults(shuffleId: Int): Int = {
@@ -733,9 +741,6 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
   }
 
-  // TODO do we need to provide a method to calculate missing maps for a MergeStatus of a
-  // TODO given shuffle?
-
   /**
    * Grouped function of Range, this is to avoid traverse of all elements of Range using
    * IterableLike's grouped function.
@@ -949,7 +954,6 @@ private[spark] class MapOutputTrackerMaster(
       return epoch
     }
   }
-  // TODO should we have a separate epoch for merge status?
 
   // This method is only called in local-mode.
   def getMapSizesByExecutorId(
@@ -975,7 +979,7 @@ private[spark] class MapOutputTrackerMaster(
 
   // This method is only called in local-mode. Since push based shuffle won't be
   // enabled in local-mode, this method returns empty list.
-  def getMapSizesForMergeResult(
+  override def getMapSizesForMergeResult(
       shuffleId: Int,
       partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     Seq.empty
@@ -983,7 +987,7 @@ private[spark] class MapOutputTrackerMaster(
 
   // This method is only called in local-mode. Since push based shuffle won't be
   // enabled in local-mode, this method returns empty list.
-  def getMapSizesForMergeResult(
+  override def getMapSizesForMergeResult(
       shuffleId: Int,
       partitionId: Int,
       chunkTracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
@@ -1051,11 +1055,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }
 
-  /**
-   * Called from executors upon fetch failure on a merged shuffle block. This is to get the server
-   * URIs and output sizes for each shuffle block that is merged in the specified merged shuffle
-   * block so fetch failure on a merged shuffle block can fall back to fetching the unmerged blocks.
-   */
   override def getMapSizesForMergeResult(
       shuffleId: Int,
       partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
@@ -1065,9 +1064,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
     try {
       val mergeStatus = mergeResultStatuses(partitionId)
-      // The original MergeStatus is no longer available, we cannot identify the list of
-      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException.
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch. Throw MetadataFetchFailedException in this case.
       MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition-level bitmap since we are doing partition-level fallback
       MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
         mergeStatus.tracker)
     } catch {
@@ -1079,12 +1079,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }
 
-  /**
-   * Called from executors upon fetch failure on a merged shuffle block chunk. This is to get the
-   * server URIs and output sizes for each shuffle block that is merged in the specified merged
-   * shuffle block chunk so fetch failure on a merged shuffle block chunk can fall back to fetching
-   * the unmerged blocks.
-   */
   override def getMapSizesForMergeResult(
       shuffleId: Int,
       partitionId: Int,
@@ -1221,7 +1215,7 @@ private[spark] object MapOutputTracker extends Logging {
     }
   }
 
-  // Opposite of serializeMapStatuses.
+  // Opposite of serializeOutputStatuses.
   def deserializeOutputStatuses[T <: OutputStatus](
       bytes: Array[Byte], conf: SparkConf): Array[T] = {
     assert (bytes.length > 0)
@@ -1261,13 +1255,16 @@ private[spark] object MapOutputTracker extends Logging {
    * stored at that block manager.
    * Note that empty blocks are filtered in the result.
    *
+   * If push-based shuffle is enabled and an array of merge statuses is available, prioritize
+   * the locations of the merged shuffle partitions over unmerged shuffle blocks.
+   *
    * If any of the statuses is null (indicating a missing location due to a failed mapper),
    * throws a FetchFailedException.
    *
    * @param shuffleId Identifier for the shuffle
    * @param startPartition Start of map output partition ID range (included in range)
    * @param endPartition End of map output partition ID range (excluded from range)
-   * @param mapStatuses List of map statuses, indexed by map ID.
+ * @param mapStatuses List of map statuses, indexed by map partition index. * @param startMapIndex Start Map index. * @param endMapIndex End Map index. * @param mergeStatuses List of merge statuses, index by reduce ID. @@ -1286,6 +1283,11 @@ private[spark] object MapOutputTracker extends Logging { Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { assert (mapStatuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle + // partition consists of blocks merged in random order, we are unable to serve map index + // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of + // map outputs, it usually indicates skewed partitions which push-based shuffle delegates + // to AQE to handle. if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) { // We have MergeStatus and full range of mapIds are requested so return a merged block. val numMaps = mapStatuses.length @@ -1293,14 +1295,12 @@ private[spark] object MapOutputTracker extends Logging { case (mergeStatus, partId) => val remainingMapStatuses = if (mergeStatus != null) { // If MergeStatus is available for the given partition, add location of the - // pre-merged shuffle partition for this partition ID - // TODO check recent upstream patch to see if there's a better way to handle - // TODO the dummy mapper ID here. For now, it is probably fine to use -1. + // pre-merged shuffle partition for this partition ID. Here we create a + // ShuffleBlockId with mapId being -1 to indicate this is a merged shuffle block. splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += ((ShuffleBlockId(shuffleId, -1, partId), mergeStatus.totalSize, -1)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper - // shuffle partition blocks, fall back to fetching the original map produced - // shuffle partition blocks + // shuffle partition blocks, fetch the original map produced shuffle partition blocks mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex) } else { // If MergeStatus is not available for the given partition, fall back to @@ -1330,20 +1330,22 @@ private[spark] object MapOutputTracker extends Logging { } } } + splitsByAddress.mapValues(_.toSeq).iterator } /** - * Given a partition ID, an array of map statuses, and an array of merge statuses, identify - * the metadata about the shuffle partition blocks that are merged into the shuffle block - * represented by the merge status in the array specified by the given partition ID. + * Given a shuffle ID, a partition ID, an array of map statuses, and bitmap corresponding + * to either a merged shuffle partition or a merged shuffle partition chunk, identify + * the metadata about the shuffle partition blocks that are merged into the merged shuffle + * partition or partition chunk represented by the bitmap. * * @param shuffleId Identifier for the shuffle * @param partitionId The partition ID of the MergeStatus for which we look for the metadata * of the merged shuffle partition blocks * @param mapStatuses List of map statuses, indexed by map ID - * @param tracker bitmap containing mapIds that belong to the merged block or merged block - * chunk. + * @param tracker bitmap containing mapIndexes that belong to the merged block or merged + * block chunk. 
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index df40bec6dede..5bdccb18e359 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils /** * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing - * code to handle MergeStatus. + * code to handle MergeStatus inside [[MapOutputTracker]]. */ private[spark] trait OutputStatus diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index f0b988ee0335..05b6af9fad4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -26,10 +26,19 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils /** - * The status for the result of shuffle partition merge for each individual reducer partition - * maintained by the scheduler. The scheduler would separate the MergeStatuses received from - * ExternalShuffleService into individual MergeStatus which is maintained inside MapOutputTracker - * to be served to the reducers when they start fetching shuffle partition blocks. + * The status for the result of merging shuffle partition blocks per individual shuffle partition + * maintained by the scheduler. The scheduler would separate the [[MergeStatuses]] received from + * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside + * [[MapOutputTracker]] to be served to the reducers when they start fetching shuffle partition + * blocks. Note that, the reducers are ultimately fetching individual chunks inside a merged + * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]]. + * Between the scheduler maintained MergeStatus and the shuffle service maintained per shuffle + * partition meta file, we are effectively dividing the metadata for a push-based shuffle into + * 2 layers. The scheduler would track the top-level metadata at the shuffle partition level + * with MergeStatus, and the shuffle service would maintain the partition level metadata about + * how to further divide a merged shuffle partition into multiple chunks with the per-partition + * meta file. This helps to reduce the amount of data the scheduler needs to maintain for + * push-based shuffle. 
*/ private[spark] class MergeStatus( private[this] var loc: BlockManagerId, @@ -78,7 +87,9 @@ private[spark] object MergeStatus { def convertMergeStatusesToMergeStatusArr( mergeStatuses: MergeStatuses, loc: BlockManagerId): Seq[(Int, MergeStatus)] = { - val mergerLoc = BlockManagerId("", loc.host, loc.port) + assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length && + mergeStatuses.bitmaps.length == mergeStatuses.sizes.length) + val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port) mergeStatuses.bitmaps.zipWithIndex.map { case (bitmap, index) => val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index)) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index f9bd75d5a75d..fc828b08da1b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -class MapOutputTrackerSuite extends SparkFunSuite { +class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf private def newTrackerMaster(sparkConf: SparkConf = conf) = { @@ -333,7 +333,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("master register and unregister merge result") { + test("SPARK-32921: master register and unregister merge result") { val rpcEnv = createRpcEnv("test") val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, @@ -355,9 +355,51 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("get map statuses from merged shuffle") { - conf.set("spark.shuffle.push.based.enabled", "true") - conf.set("spark.shuffle.service.enabled", "true") + test("SPARK-32921: get map sizes with merged shuffle") { + conf.set("spark.shuffle.push.enabled", "true") + conf.set("spark.testing", "true") + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + bitmap.add(3) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + bitmap, 3000L)) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + 
    assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: get map statuses from merged shuffle") {
+    conf.set("spark.shuffle.push.enabled", "true")
+    conf.set("spark.testing", "true")
     val hostname = "localhost"
     val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
 
@@ -381,14 +423,10 @@
     bitmap.add(3)
 
     val blockMgrId = BlockManagerId("a", "hostA", 1000)
-    masterTracker.registerMapOutput(
-      10, 0, MapStatus(blockMgrId, Array(1000L), 0))
-    masterTracker.registerMapOutput(
-      10, 1, MapStatus(blockMgrId, Array(1000L), 1))
-    masterTracker.registerMapOutput(
-      10, 2, MapStatus(blockMgrId, Array(1000L), 2))
-    masterTracker.registerMapOutput(
-      10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
 
     masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
       bitmap, 4000L))
@@ -404,9 +442,9 @@
     slaveRpcEnv.shutdown()
   }
 
-  test("get map statuses for merged shuffle block chunks") {
-    conf.set("spark.shuffle.push.based.enabled", "true")
-    conf.set("spark.shuffle.service.enabled", "true")
+  test("SPARK-32921: get map statuses for merged shuffle block chunks") {
+    conf.set("spark.shuffle.push.enabled", "true")
+    conf.set("spark.testing", "true")
     val hostname = "localhost"
     val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
 
@@ -441,4 +479,68 @@
     rpcEnv.shutdown()
     slaveRpcEnv.shutdown()
   }
+
+  test("SPARK-32921: getPreferredLocationsForShuffle with MergeStatus") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    sc = new SparkContext("local", "test", conf.clone())
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    // Setup 6 map tasks
+    // on hostA with output size 2
+    // on hostA with output size 2
+    // on hostB with output size 3
+    // on hostB with output size 3
+    // on hostC with output size 1
+    // on hostC with output size 1
+    tracker.registerShuffle(10, 6, 1)
+    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 5))
+    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 6))
+    tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 7))
+    tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 8))
+    tracker.registerMapOutput(10, 4, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 9))
+    tracker.registerMapOutput(10, 5, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 10))
+
+    val rdd = sc.parallelize(1 to 6, 6).map(num => (num, num).asInstanceOf[Product2[Int, Int]])
+    val mockShuffleDep = mock(classOf[ShuffleDependency[Int, Int, _]])
+    when(mockShuffleDep.shuffleId).thenReturn(10)
+    when(mockShuffleDep.partitioner).thenReturn(new HashPartitioner(1))
    when(mockShuffleDep.rdd).thenReturn(rdd)
+
+    // Prepare a MergeStatus that merges 5 out of 6 blocks
+    val bitmap80 = new RoaringBitmap()
+    bitmap80.add(0)
+    bitmap80.add(1)
+    bitmap80.add(2)
+    bitmap80.add(3)
+    bitmap80.add(4)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap80, 11))
+
+    val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs1.nonEmpty)
+    assert(preferredLocs1.length === 1)
+    assert(preferredLocs1.head === "hostA")
+
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+    // Prepare another MergeStatus that merges only 1 out of 6 blocks
+    val bitmap20 = new RoaringBitmap()
+    bitmap20.add(0)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap20, 2))
+
+    val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs2.nonEmpty)
+    assert(preferredLocs2.length === 2)
+    assert(preferredLocs2 === Seq("hostA", "hostB"))
+
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
 }

From 9fee2effc1aa0fe2ee6bff6fc7a506c51f8a8b61 Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Tue, 24 Nov 2020 10:53:35 -0800
Subject: [PATCH 08/24] Fix Scala 2.13 compatibility issue

---
 .../main/scala/org/apache/spark/MapOutputTracker.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index d22ede4e28bf..f3f17e58881c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1057,7 +1057,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   override def getMapSizesForMergeResult(
       shuffleId: Int,
-      partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
     // Fetch the map statuses and merge statuses again since they might have already been
     // cleared by another task running in the same executor.
@@ -1082,7 +1082,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   override def getMapSizesForMergeResult(
       shuffleId: Int,
       partitionId: Int,
-      chunkTracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
     // Fetch the map statuses and merge statuses again since they might have already been
     // cleared by another task running in the same executor.
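PATCH 08 above is a Scala 2.13 cross-build fix: in 2.13 the default `Seq` is `scala.collection.immutable.Seq`, so a mutable `ListBuffer` no longer conforms to the declared `Seq[(BlockId, Long, Int)]` element type, and `Map#mapValues` now returns a lazy `MapView` rather than a strict `Map`. Converting each buffer with `_.toSeq` and exposing only an `Iterator` compiles on both versions and hides the mutable buffers from callers. A small sketch of the accumulation pattern, with plain `String` keys standing in for `BlockManagerId` (an assumption for brevity, not code from the patch):

    import scala.collection.mutable.{HashMap, ListBuffer}

    object GroupBlocks {
      // Group (address, blockId, size) triples by address and expose the result
      // as an iterator of Seqs, mirroring the splitsByAddress pattern above.
      def byAddress(blocks: Seq[(String, String, Long)]): Iterator[(String, Seq[(String, Long)])] = {
        val splitsByAddress = new HashMap[String, ListBuffer[(String, Long)]]
        for ((address, blockId, size) <- blocks) {
          splitsByAddress.getOrElseUpdate(address, ListBuffer()) += ((blockId, size))
        }
        // On 2.13, mapValues yields a lazy view and ListBuffer.toSeq copies to an
        // immutable Seq; a single pass over the iterator materializes each value once.
        splitsByAddress.mapValues(_.toSeq).iterator
      }
    }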
@@ -1354,7 +1354,7 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, partitionId: Int, mapStatuses: Array[MapStatus], - tracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { assert (mapStatuses != null && tracker != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] for ((status, mapIndex) <- mapStatuses.zipWithIndex) { @@ -1366,7 +1366,7 @@ private[spark] object MapOutputTracker extends Logging { status.getSizeForBlock(partitionId), mapIndex)) } } - splitsByAddress.toSeq + splitsByAddress.mapValues(_.toSeq).iterator } def validateStatus(status: OutputStatus, shuffleId: Int, partition: Int) : Unit = { From 10f3079aa03333aed51e81ffce2a200c67131921 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Tue, 24 Nov 2020 16:20:31 -0800 Subject: [PATCH 09/24] Fix build issues --- .../scala/org/apache/spark/MapOutputTracker.scala | 12 ++++++------ .../org/apache/spark/MapOutputTrackerSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index f3f17e58881c..5e309cce8f0a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -489,7 +489,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ def getMapSizesForMergeResult( shuffleId: Int, - partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** * Called from executors upon fetch failure on a merged shuffle partition chunk. This is to get @@ -500,7 +500,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesForMergeResult( shuffleId: Int, partitionId: Int, - chunkBitmap: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] + chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** * Deletes map output status information for the specified shuffle stage. @@ -981,8 +981,8 @@ private[spark] class MapOutputTrackerMaster( // enabled in local-mode, this method returns empty list. override def getMapSizesForMergeResult( shuffleId: Int, - partitionId: Int): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq.empty + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator } // This method is only called in local-mode. 
Since push based shuffle won't be @@ -990,8 +990,8 @@ private[spark] class MapOutputTrackerMaster( override def getMapSizesForMergeResult( shuffleId: Int, partitionId: Int, - chunkTracker: RoaringBitmap): Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq.empty + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator } override def stop(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index fc828b08da1b..0f399ae2376a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -432,7 +432,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { bitmap, 4000L)) slaveTracker.updateEpoch(masterTracker.getEpoch) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - assert(slaveTracker.getMapSizesForMergeResult(10, 0) === + assert(slaveTracker.getMapSizesForMergeResult(10, 0).toSeq === Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0), (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2), (ShuffleBlockId(10, 3, 0), size1000, 3))))) @@ -470,7 +470,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { chunkBitmap.add(0) chunkBitmap.add(2) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap) === + assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap).toSeq === Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0), (ShuffleBlockId(10, 2, 0), size1000, 2)))) ) From b37efa4fbc4fff97d29d80dbe3064a0c5ff6c074 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Wed, 25 Nov 2020 10:00:58 -0800 Subject: [PATCH 10/24] Fix javadoc issue --- .../src/main/scala/org/apache/spark/scheduler/MapStatus.scala | 4 ++-- .../main/scala/org/apache/spark/scheduler/MergeStatus.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 5bdccb18e359..ecd5d6647372 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils /** * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing - * code to handle MergeStatus inside [[MapOutputTracker]]. + * code to handle MergeStatus inside MapOutputTracker. */ private[spark] trait OutputStatus @@ -40,7 +40,7 @@ private[spark] trait OutputStatus * on to the reduce tasks. */ private[spark] sealed trait MapStatus extends OutputStatus { - /** Location where this task was run. */ + /** Location where this task output is. 
*/ def location: BlockManagerId def updateLocation(newLoc: BlockManagerId): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 05b6af9fad4f..85219712b47c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * The status for the result of merging shuffle partition blocks per individual shuffle partition * maintained by the scheduler. The scheduler would separate the [[MergeStatuses]] received from * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside - * [[MapOutputTracker]] to be served to the reducers when they start fetching shuffle partition + * MapOutputTracker to be served to the reducers when they start fetching shuffle partition * blocks. Note that, the reducers are ultimately fetching individual chunks inside a merged * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]]. * Between the scheduler maintained MergeStatus and the shuffle service maintained per shuffle From a188649b45efb30cfce0f959439ae25c59872214 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Wed, 25 Nov 2020 11:50:29 -0800 Subject: [PATCH 11/24] Fix more javadoc issue --- .../main/scala/org/apache/spark/scheduler/MergeStatus.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 85219712b47c..86c9ce064b68 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -27,7 +27,8 @@ import org.apache.spark.util.Utils /** * The status for the result of merging shuffle partition blocks per individual shuffle partition - * maintained by the scheduler. The scheduler would separate the [[MergeStatuses]] received from + * maintained by the scheduler. The scheduler would separate the + * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside * MapOutputTracker to be served to the reducers when they start fetching shuffle partition * blocks. 
Note that, the reducers are ultimately fetching individual chunks inside a merged From 3c5fc126cc1388d4437b144c44988ccb65143b3b Mon Sep 17 00:00:00 2001 From: Min Shen Date: Wed, 2 Dec 2020 12:38:54 -0800 Subject: [PATCH 12/24] Address review comments --- .../scala/org/apache/spark/MapOutputTracker.scala | 15 ++++++++------- .../org/apache/spark/scheduler/MapStatus.scala | 4 ++-- .../org/apache/spark/scheduler/MergeStatus.scala | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e309cce8f0a..0b58980fcc8c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{MapStatus, MergeStatus, OutputStatus} +import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -396,7 +396,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort - logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}") + logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") tracker.post(GetOutputStatusesMessage(shuffleId, fetchMapOutput = true, context)) case GetMergeResultStatuses(shuffleId: Int) => @@ -729,7 +729,8 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0) } - def getNumAvailableMergeResults(shuffleId: Int): Int = { + /** VisibleForTest. Invoked in test only. */ + private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = { shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0) } @@ -1171,7 +1172,7 @@ private[spark] object MapOutputTracker extends Logging { // Serialize an array of map/merge output locations into an efficient byte format so that we can // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will // generally be pretty compressible because many outputs will be on the same hostname. - def serializeOutputStatuses[T <: OutputStatus]( + def serializeOutputStatuses[T <: ShuffleOutputStatus]( statuses: Array[T], broadcastManager: BroadcastManager, isLocal: Boolean, @@ -1216,7 +1217,7 @@ private[spark] object MapOutputTracker extends Logging { } // Opposite of serializeOutputStatuses. 
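The serializer/deserializer pair above is generic over ShuffleOutputStatus, so one code path now covers both map and merge statuses. As a rough sketch of the byte layout involved (a one-byte DIRECT/BROADCAST header followed by a compressed, serialized array), consider the following self-contained Scala; it substitutes JDK GZIP and plain Java serialization for Spark's configurable codec and broadcast machinery, so treat it as an illustration of the shape of the data, not the actual implementation:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    object StatusCodecSketch {
      val DIRECT: Byte = 0     // payload serialized inline
      val BROADCAST: Byte = 1  // real code stores a Broadcast handle here instead

      def serialize[T <: Serializable](statuses: Array[T]): Array[Byte] = {
        val buf = new ByteArrayOutputStream()
        buf.write(DIRECT.toInt)  // header byte picks the decoding branch
        val out = new ObjectOutputStream(new GZIPOutputStream(buf))
        try out.writeObject(statuses) finally out.close()
        buf.toByteArray
      }

      def deserialize(bytes: Array[Byte]): Array[_] = {
        require(bytes.nonEmpty && bytes(0) == DIRECT, "sketch only handles DIRECT")
        val in = new ObjectInputStream(
          new GZIPInputStream(new ByteArrayInputStream(bytes, 1, bytes.length - 1)))
        try in.readObject().asInstanceOf[Array[_]] finally in.close()
      }
    }

The real method additionally compares the payload size against minBroadcastSize and, above the threshold, serializes a Broadcast[Array[Byte]] handle instead, which is why ShuffleStatus caches both the raw bytes and the broadcast reference.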
- def deserializeOutputStatuses[T <: OutputStatus]( + def deserializeOutputStatuses[T <: ShuffleOutputStatus]( bytes: Array[Byte], conf: SparkConf): Array[T] = { assert (bytes.length > 0) @@ -1293,7 +1294,7 @@ private[spark] object MapOutputTracker extends Logging { val numMaps = mapStatuses.length mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { case (mergeStatus, partId) => - val remainingMapStatuses = if (mergeStatus != null) { + val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { // If MergeStatus is available for the given partition, add location of the // pre-merged shuffle partition for this partition ID. Here we create a // ShuffleBlockId with mapId being -1 to indicate this is a merged shuffle block. @@ -1369,7 +1370,7 @@ private[spark] object MapOutputTracker extends Logging { splitsByAddress.mapValues(_.toSeq).iterator } - def validateStatus(status: OutputStatus, shuffleId: Int, partition: Int) : Unit = { + def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition" logError(errorMessage) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index ecd5d6647372..07eed76805dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -32,14 +32,14 @@ import org.apache.spark.util.Utils * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing * code to handle MergeStatus inside MapOutputTracker. */ -private[spark] trait OutputStatus +private[spark] trait ShuffleOutputStatus /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus extends OutputStatus { +private[spark] sealed trait MapStatus extends ShuffleOutputStatus { /** Location where this task output is. 
*/ def location: BlockManagerId diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 86c9ce064b68..ed7e14b7efe4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -45,7 +45,7 @@ private[spark] class MergeStatus( private[this] var loc: BlockManagerId, private[this] var mapTracker: RoaringBitmap, private[this] var size: Long) - extends Externalizable with OutputStatus { + extends Externalizable with ShuffleOutputStatus { protected def this() = this(null, null, -1) // For deserialization only From 384de48ed42901f98427eefdfeeb0c3b676693c6 Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Thu, 18 Mar 2021 11:51:16 -0700 Subject: [PATCH 13/24] Address Mridul's review comments --- .../org/apache/spark/MapOutputTracker.scala | 110 +++++++++++------- .../apache/spark/scheduler/DAGScheduler.scala | 3 +- .../apache/spark/scheduler/MergeStatus.scala | 9 ++ .../apache/spark/MapOutputTrackerSuite.scala | 31 ++--- .../spark/MapStatusesSerDeserBenchmark.scala | 4 +- .../scala/org/apache/spark/ShuffleSuite.scala | 6 +- .../spark/storage/BlockManagerSuite.scala | 4 +- 7 files changed, 103 insertions(+), 64 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 0b58980fcc8c..97c751b338d6 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -50,7 +50,10 @@ import org.apache.spark.util._ * * All public methods of this class are thread-safe. */ -private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Logging { +private class ShuffleStatus( + numPartitions: Int, + numReducers: Int, + isPushBasedShuffleEnabled: Boolean = false) extends Logging { private val (readLock, writeLock) = { val lock = new ReentrantReadWriteLock() @@ -94,7 +97,11 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin * provides a reducer oriented view of the shuffle status specifically for the results of * merging shuffle partition blocks into per-partition merged shuffle files. */ - val mergeStatuses = new Array[MergeStatus](numReducers) + val mergeStatuses = if (isPushBasedShuffleEnabled) { + new Array[MergeStatus](numReducers) + } else { + Array.empty[MergeStatus] + } /** * The cached result of serializing the map statuses array. This cache is lazily populated when @@ -182,7 +189,7 @@ private class ShuffleStatus(numPartitions: Int, numReducers: Int) extends Loggin * Register a merge result. */ def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock { - if (mergeStatuses(reduceId) == null) { + if (mergeStatuses(reduceId) != status) { _numAvailableMergeResults += 1 invalidateSerializedMergeOutputStatusCache() } @@ -379,12 +386,11 @@ private[spark] case class GetMergeResultStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** - * The boolean flag in the case class indicates whether the request is for map output or not. - * If false, the request is for merge statuses instead. 
- */ -private[spark] case class GetOutputStatusesMessage(shuffleId: Int, - fetchMapOutput: Boolean, context: RpcCallContext) +private[spark] sealed trait MapOutputTrackerMasterMessage +private[spark] case class GetMapStatusMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetMergeStatusMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( @@ -397,12 +403,12 @@ private[spark] class MapOutputTrackerMasterEndpoint( case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") - tracker.post(GetOutputStatusesMessage(shuffleId, fetchMapOutput = true, context)) + tracker.post(GetMapStatusMessage(shuffleId, context)) case GetMergeResultStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send merge result locations for shuffle $shuffleId to $hostPort") - tracker.post(GetOutputStatusesMessage(shuffleId, fetchMapOutput = false, context)) + tracker.post(GetMergeStatusMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -486,6 +492,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * partition. This method is to get the server URIs and output sizes for each shuffle block that * is merged in the specified merged shuffle block so fetch failure on a merged shuffle block can * fall back to fetching the unmerged blocks. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. */ def getMapSizesForMergeResult( shuffleId: Int, @@ -496,6 +506,14 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * the server URIs and output sizes for each shuffle block that is merged in the specified merged * shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back to * fetching the unmerged blocks. + * + * chunkBitMap tracks the mapIds which are part of the current merged chunk, this way if there is + * a fetch failure on the merged chunk, it can fallback to fetching the corresponding original + * blocks part of this merged chunk. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesForMergeResult( shuffleId: Int, @@ -550,8 +568,9 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // requests for map/merge output statuses - private val outputStatusesRequests = new LinkedBlockingQueue[GetOutputStatusesMessage] + // requests for MapOutputTrackerMasterMessages + private val mapOutputTrackerMasterMessages = + new LinkedBlockingQueue[MapOutputTrackerMasterMessage] private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) @@ -576,8 +595,8 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetOutputStatusesMessage): Unit = { - outputStatusesRequests.offer(message) + def post(message: MapOutputTrackerMasterMessage): Unit = { + mapOutputTrackerMasterMessages.offer(message) } /** Message loop used for dispatching messages. */ @@ -586,28 +605,31 @@ private[spark] class MapOutputTrackerMaster( try { while (true) { try { - val data = outputStatusesRequests.take() + val data = mapOutputTrackerMasterMessages.take() if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it. - outputStatusesRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) return } - val context = data.context - val shuffleId = data.shuffleId - val hostPort = context.senderAddress.hostPort - val shuffleStatus = shuffleStatuses.get(shuffleId).head - if (data.fetchMapOutput) { - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) - context.reply( - shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, minSizeForBroadcast, - conf, isMapOutput = true)) - } else { - logDebug("Handling request to send merge output locations for shuffle " + shuffleId + - " to " + hostPort) - context.reply( - shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, minSizeForBroadcast, - conf, isMapOutput = false)) + + data match { + case GetMapStatusMessage(shuffleId, context) => + val hostPort = context.senderAddress.hostPort + val shuffleStatus = shuffleStatuses.get(shuffleId).head + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + context.reply( + shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, + minSizeForBroadcast, conf, isMapOutput = true)) + + case GetMergeStatusMessage(shuffleId, context) => + val hostPort = context.senderAddress.hostPort + val shuffleStatus = shuffleStatuses.get(shuffleId).head + logDebug("Handling request to send merge output locations for" + + " shuffle " + shuffleId + " to " + hostPort) + context.reply( + shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, + minSizeForBroadcast, conf, isMapOutput = false)) } } catch { case NonFatal(e) => logError(e.getMessage, e) @@ -620,15 +642,16 @@ private[spark] class MapOutputTrackerMaster( } /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new GetOutputStatusesMessage(-99, true, null) + private val PoisonPill = GetMapStatusMessage(-99, null) // Used only in unit tests. 
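The PoisonPill above is the usual sentinel-shutdown idiom for dispatcher threads draining a blocking queue: whichever loop dequeues it re-offers it so its sibling threads also exit. Reduced to its essentials (the names and the two-thread pool are invented for this sketch):

    import java.util.concurrent.{Executors, LinkedBlockingQueue}

    object LoopSketch {
      sealed trait Msg
      final case class Request(shuffleId: Int) extends Msg
      case object PoisonPill extends Msg

      val queue = new LinkedBlockingQueue[Msg]()
      val pool = Executors.newFixedThreadPool(2)

      class MessageLoop extends Runnable {
        override def run(): Unit = {
          var live = true
          while (live) {
            queue.take() match {
              case PoisonPill =>
                queue.offer(PoisonPill) // put it back so sibling loops also exit
                live = false
              case Request(id) =>
                println(s"serving statuses for shuffle $id")
            }
          }
        }
      }

      def start(): Unit = (1 to 2).foreach(_ => pool.execute(new MessageLoop))
      def stop(): Unit = { queue.offer(PoisonPill); pool.shutdown() }
    }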
private[spark] def getNumCachedSerializedBroadcast: Int = { shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) } - def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int = 0): Unit = { - if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { + def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { + if (shuffleStatuses.put(shuffleId, + new ShuffleStatus(numMaps, numReduces, pushBasedShuffleEnabled)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -828,7 +851,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.withMergeStatuses { statuses => val status = statuses(partitionId) val numMaps = dep.rdd.partitions.length - if (status != null && status.getMissingMaps(numMaps).length.toDouble / numMaps + if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps <= (1 - REDUCER_PREF_LOCS_FRACTION)) { Seq(status.location.host) } else { @@ -996,7 +1019,7 @@ private[spark] class MapOutputTrackerMaster( } override def stop(): Unit = { - outputStatusesRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) threadpool.shutdown() try { sendTracker(StopMapOutputTracker) @@ -1068,11 +1091,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // If the original MergeStatus is no longer available, we cannot identify the list of // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case. MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId) - // User the MergeStatus's partition level bitmap since we are doing partition level fallback + // Use the MergeStatus's partition level bitmap since we are doing partition level fallback MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses, mergeStatus.tracker) } catch { - // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it case e: MetadataFetchFailedException => mapStatuses.clear() mergeStatuses.clear() @@ -1169,6 +1192,8 @@ private[spark] object MapOutputTracker extends Logging { private val DIRECT = 0 private val BROADCAST = 1 + private val SHUFFLE_PUSH_MAP_ID = -1 + // Serialize an array of map/merge output locations into an efficient byte format so that we can // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will // generally be pretty compressible because many outputs will be on the same hostname. @@ -1297,9 +1322,10 @@ private[spark] object MapOutputTracker extends Logging { val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { // If MergeStatus is available for the given partition, add location of the // pre-merged shuffle partition for this partition ID. Here we create a - // ShuffleBlockId with mapId being -1 to indicate this is a merged shuffle block. + // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is + // a merged shuffle block. 
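A MergeStatus records which map outputs were folded into the merged file as a RoaringBitmap, and the fallback path above walks that bitmap to find the "holes" that must still be fetched as original per-mapper blocks. A toy rendering of that bookkeeping, using the same org.roaringbitmap library with made-up ids:

    import org.roaringbitmap.RoaringBitmap

    object MergeBitmapSketch {
      def main(args: Array[String]): Unit = {
        val numMaps = 6
        val tracker = new RoaringBitmap()
        Seq(0, 2, 3, 5).foreach(i => tracker.add(i)) // map outputs that were merged

        // Unmerged "holes" must be fetched as original, per-mapper shuffle blocks.
        val missing = (0 until numMaps).filter(i => !tracker.contains(i))
        println(s"merged=${tracker.getCardinality} missing=$missing") // missing = 1, 4
      }
    }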
splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, -1, partId), mergeStatus.totalSize, -1)) + ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper // shuffle partition blocks, fetch the original map produced shuffle partition blocks mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2e7c4dae038..a92d9fab6efc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -456,7 +456,8 @@ private[spark] class DAGScheduler( // since we can't do it in the RDD constructor because # of partitions is unknown logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " + s"shuffle ${shuffleDep.shuffleId}") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length, + shuffleDep.partitioner.numPartitions) } stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index ed7e14b7efe4..99e01941d9fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -64,6 +64,13 @@ private[spark] class MergeStatus( (0 until numMaps).filter(i => !mapTracker.contains(i)) } + /** + * Get the number of missing map outputs for missing mapper partition blocks that are not merged. 
+ */ + def getNumMissingMapOutputs(numMaps: Int): Int = { + getMissingMaps(numMaps).length + } + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) mapTracker.writeExternal(out) @@ -79,6 +86,8 @@ private[spark] class MergeStatus( } private[spark] object MergeStatus { + // Dummy number of reduces for the tests where push based shuffle is not enabled + val SHUFFLE_PUSH_DUMMY_NUM_REDUCES = 1 /** * Separate a MergeStatuses received from an ExternalShuffleService into individual diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 0f399ae2376a..4dae20020149 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException @@ -59,7 +60,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) @@ -83,7 +84,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -106,7 +107,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -141,7 +142,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { mapWorkerTracker.trackerEndpoint = mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) mapWorkerTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. 
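Every registerShuffle call site now supplies a reducer count, so suites that never exercise push-based shuffle pass the dummy constant, while the push-based tests first flip the two conf flags. Condensed from the surrounding suite (assuming its newTrackerMaster and createRpcEnv helpers and imports), the setup reads:

    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
    conf.set(IS_TESTING, true)
    val rpcEnv = createRpcEnv("test")
    val tracker = newTrackerMaster()
    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
    // numReduces sizes ShuffleStatus.mergeStatuses when push-based shuffle is on;
    // otherwise the dummy constant preserves the old two-argument behavior.
    tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)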
intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } @@ -184,7 +185,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Message size should be ~123B, and no exception should be thrown - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5)) val senderAddress = RpcAddress("localhost", 12345) @@ -218,7 +219,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { // on hostA with output size 2 // on hostA with output size 2 // on hostB with output size 3 - tracker.registerShuffle(10, 3) + tracker.registerShuffle(10, 3, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(2L), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -261,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. - masterTracker.registerShuffle(20, 100) + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) @@ -307,7 +308,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -334,6 +335,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-32921: master register and unregister merge result") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) val rpcEnv = createRpcEnv("test") val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, @@ -356,8 +359,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-32921: get map sizes with merged shuffle") { - conf.set("spark.shuffle.push.enabled", "true") - conf.set("spark.testing", "true") + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) @@ -398,8 +401,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-32921: get map statuses from merged shuffle") { - conf.set("spark.shuffle.push.enabled", "true") - conf.set("spark.testing", "true") + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) @@ -443,8 +446,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-32921: get map statuses for merged shuffle block chunks") { 
- conf.set("spark.shuffle.push.enabled", "true") - conf.set("spark.testing", "true") + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 1af1e0f778cf..38274d0c5c70 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -21,7 +21,7 @@ import org.scalatest.Assertions._ import org.apache.spark.benchmark.Benchmark import org.apache.spark.benchmark.BenchmarkBase -import org.apache.spark.scheduler.CompressedMapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, MergeStatus} import org.apache.spark.storage.BlockManagerId /** @@ -52,7 +52,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { val shuffleId = 10 - tracker.registerShuffle(shuffleId, numMaps) + tracker.registerShuffle(shuffleId, numMaps, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val r = new scala.util.Random(912) (0 until numMaps).foreach { i => tracker.registerMapOutput(shuffleId, i, diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 56684d9b0327..126faec334e7 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File import java.util.{Locale, Properties} -import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService } import scala.collection.JavaConverters._ @@ -33,7 +33,7 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} @@ -367,7 +367,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, shuffleDep) - mapTrackerMaster.registerShuffle(0, 1) + mapTrackerMaster.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) // first attempt -- its successful val context1 = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 055ee0debeb1..707e1684f78f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -53,7 +53,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, 
ExternalBlockStoreClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} -import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated} +import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, MergeStatus, SparkListenerBlockUpdated} import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend} import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} @@ -1956,7 +1956,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent) Files.write(bm2.diskBlockManager.getFile(shuffleIndex2).toPath(), shuffleIndexBlockContent) - mapOutputTracker.registerShuffle(0, 1) + mapOutputTracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val decomManager = new BlockManagerDecommissioner(conf, bm1) try { mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0)) From 694c8e7519fb45655ed11d0726e514dcc3a55530 Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Sun, 28 Mar 2021 16:36:04 -0700 Subject: [PATCH 14/24] Addressed review comments of ngone51 and mridulm --- .../org/apache/spark/MapOutputTracker.scala | 57 ++++++++++--------- .../apache/spark/scheduler/MergeStatus.scala | 2 +- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 97c751b338d6..7995c5990d44 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -487,11 +487,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** - * Called from executors upon fetch failure on an entire merged shuffle partition. Such failures - * can happen if the shuffle client fails to fetch the metadata for the given merged shuffle - * partition. This method is to get the server URIs and output sizes for each shuffle block that - * is merged in the specified merged shuffle block so fetch failure on a merged shuffle block can - * fall back to fetching the unmerged blocks. + * Called from executors upon fetch failure on an entire merged shuffle reduce partition. + * Such failures can happen if the shuffle client fails to fetch the metadata for the given + * merged shuffle partition. This method is to get the server URIs and output sizes for each + * shuffle block that is merged in the specified merged shuffle block so fetch failure on a + * merged shuffle block can fall back to fetching the unmerged blocks. * * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) @@ -502,10 +502,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** - * Called from executors upon fetch failure on a merged shuffle partition chunk. 
This is to get
- * the server URIs and output sizes for each shuffle block that is merged in the specified merged
- * shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back to
- * fetching the unmerged blocks.
+ * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is
+ * to get the server URIs and output sizes for each shuffle block that is merged in the specified
+ * merged shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back
+ * to fetching the unmerged blocks.
  *
  * chunkBitMap tracks the mapIds which are part of the current merged chunk, this way if there is
  * a fetch failure on the merged chunk, it can fallback to fetching the corresponding original
  * blocks part of this merged chunk.
@@ -601,6 +601,24 @@ private[spark] class MapOutputTrackerMaster(
 
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
+    private def handleStatusMessage(
+        shuffleId: Int,
+        context: RpcCallContext,
+        isMapOutput: Boolean): Unit = {
+      val hostPort = context.senderAddress.hostPort
+      val shuffleStatus = shuffleStatuses.get(shuffleId).head
+      val mapOrMerge = if (isMapOutput) {
+        "map"
+      } else {
+        "merge"
+      }
+      logDebug(s"Handling request to send $mapOrMerge output locations" +
+        s" for shuffle $shuffleId to $hostPort")
+      context.reply(
+        shuffleStatus.serializedOutputStatus(broadcastManager, isLocal,
+          minSizeForBroadcast, conf, isMapOutput = isMapOutput))
+    }
+
     override def run(): Unit = {
       try {
         while (true) {
@@ -614,22 +632,9 @@ private[spark] class MapOutputTrackerMaster(
 
             data match {
               case GetMapStatusMessage(shuffleId, context) =>
-                val hostPort = context.senderAddress.hostPort
-                val shuffleStatus = shuffleStatuses.get(shuffleId).head
-                logDebug("Handling request to send map output locations for shuffle " + shuffleId +
-                  " to " + hostPort)
-                context.reply(
-                  shuffleStatus.serializedOutputStatus(broadcastManager, isLocal,
-                    minSizeForBroadcast, conf, isMapOutput = true))
-
+                handleStatusMessage(shuffleId, context, true)
               case GetMergeStatusMessage(shuffleId, context) =>
-                val hostPort = context.senderAddress.hostPort
-                val shuffleStatus = shuffleStatuses.get(shuffleId).head
-                logDebug("Handling request to send merge output locations for" +
-                  " shuffle " + shuffleId + " to " + hostPort)
-                context.reply(
-                  shuffleStatus.serializedOutputStatus(broadcastManager, isLocal,
-                    minSizeForBroadcast, conf, isMapOutput = false))
+                handleStatusMessage(shuffleId, context, false)
             }
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
@@ -1134,7 +1139,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
     val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull
     if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses == null)) {
-      logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
+      logInfo(s"Don't have map/merge outputs for shuffle $shuffleId, fetching them")
       val startTimeNs = System.nanoTime()
       fetchingLock.withLock(shuffleId) {
         var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
@@ -1153,7 +1158,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
           logInfo("Got the merge output locations")
           mergeStatuses.put(shuffleId, fetchedMergeStatues)
         }
-        logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " +
+        logDebug(s"Fetching map${if (fetchMergeResult) "/merge" else ""} output statuses for shuffle $shuffleId took " +
s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") (fetchedMapStatuses, fetchedMergeStatues) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 99e01941d9fc..77d8f8e040da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -68,7 +68,7 @@ private[spark] class MergeStatus( * Get the number of missing map outputs for missing mapper partition blocks that are not merged. */ def getNumMissingMapOutputs(numMaps: Int): Int = { - getMissingMaps(numMaps).length + (0 until numMaps).count(i => !mapTracker.contains(i)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { From 3c91ed024afb580c79a05fbe707bb629dc6a3606 Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Sun, 4 Apr 2021 10:49:37 -0700 Subject: [PATCH 15/24] fix test --- .../src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 33762f558c95..93a6f5345f42 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -585,7 +585,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val err = intercept[SparkException] { MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf) } - assert(err.getMessage.contains("Unable to deserialize broadcasted map statuses")) + assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses")) } } } From 4940b57ff2b732d42c9537693d982b9da93097ed Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Tue, 6 Apr 2021 12:17:23 -0700 Subject: [PATCH 16/24] Make one rpc for both MapStatus and MergeStatus --- .../org/apache/spark/MapOutputTracker.scala | 123 ++++++++++++------ 1 file changed, 81 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 0ccdd56355c7..07d3bbaa07b6 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,17 +20,14 @@ package org.apache.spark import java.io.{ByteArrayInputStream, IOException, ObjectInputStream, ObjectOutputStream} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.locks.ReentrantReadWriteLock - import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal - import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} import org.roaringbitmap.RoaringBitmap - import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -289,45 +286,86 @@ private class ShuffleStatus( isLocal: Boolean, minBroadcastSize: Int, conf: SparkConf, - isMapOutput: Boolean): Array[Byte] = { - var result: Array[Byte] = null + isMapOutput: Boolean): (Array[Byte], Array[Byte]) = { + var mapStatuses: 
Array[Byte] = null + var mergeStatuses: Array[Byte] = null withReadLock { if (isMapOutput) { if (cachedSerializedMapStatus != null) { - result = cachedSerializedMapStatus + mapStatuses = cachedSerializedMapStatus } } else { + if (cachedSerializedMapStatus != null) { + mapStatuses = cachedSerializedMapStatus + } + if (cachedSerializedMergeStatus != null) { - result = cachedSerializedMergeStatus + mergeStatuses = cachedSerializedMergeStatus } } } - if (result == null) withWriteLock { - if (isMapOutput) { - if (cachedSerializedMapStatus == null) { - val serResult = MapOutputTracker.serializeOutputStatuses( - mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) - cachedSerializedMapStatus = serResult._1 - cachedSerializedBroadcast = serResult._2 - } - // The following line has to be outside if statement since it's possible that another - // thread initializes cachedSerializedMapStatus in-between `withReadLock` and - // `withWriteLock`. - result = cachedSerializedMapStatus - } else { - if (cachedSerializedMergeStatus == null) { - val serResult = MapOutputTracker.serializeOutputStatuses( - mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) - cachedSerializedMergeStatus = serResult._1 - cachedSerializedBroadcastMergeStatus = serResult._2 - } - // The following line has to be outside if statement for similar reasons as above. - result = cachedSerializedMergeStatus + if (isMapOutput) { + if (mapStatuses == null) { + mapStatuses = + serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf) + } + } else { + if (mapStatuses == null) { + mapStatuses = + serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf) } + + if (mergeStatuses == null) { + mergeStatuses = + serializeAndCacheMergeStatuses(broadcastManager, isLocal, minBroadcastSize, conf) + } + } + (mapStatuses, mergeStatuses) + } + + private def serializeAndCacheMapStatuses( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): Array[Byte] = { + var mapStatuses: Array[Byte] = null + withWriteLock { + if (cachedSerializedMapStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + // The following line has to be outside if statement since it's possible that another + // thread initializes cachedSerializedMapStatus in-between `withReadLock` and + // `withWriteLock`. + mapStatuses = cachedSerializedMapStatus + } + mapStatuses + } + + private def serializeAndCacheMergeStatuses( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): Array[Byte] = { + var mergeStatuses: Array[Byte] = null + withWriteLock { + if (cachedSerializedMergeStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + + // The following line has to be outside if statement since it's possible that another + // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and + // `withWriteLock`. + mergeStatuses = cachedSerializedMergeStatus } - result + mergeStatuses } // Used in testing. 
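Both helpers follow the same double-checked caching discipline as serializedMapStatus earlier in this file: probe under the read lock, then re-check under the write lock because another thread may have populated the cache in between. Stripped of the Spark types (a sketch with invented names, not the tracker's own API), the idiom is roughly:

    import java.util.concurrent.locks.ReentrantReadWriteLock

    class CachedBytes(compute: () => Array[Byte]) {
      private val lock = new ReentrantReadWriteLock()
      private var cached: Array[Byte] = _

      def get(): Array[Byte] = {
        lock.readLock().lock()
        try {
          if (cached != null) return cached   // fast path: already serialized
        } finally lock.readLock().unlock()
        lock.writeLock().lock()
        try {
          if (cached == null) {               // re-check: another thread may have won
            cached = compute()
          }
          cached
        } finally lock.writeLock().unlock()
      }
    }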
@@ -382,14 +420,14 @@ private class ShuffleStatus( private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage -private[spark] case class GetMergeResultStatuses(shuffleId: Int) +private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] sealed trait MapOutputTrackerMasterMessage private[spark] case class GetMapStatusMessage(shuffleId: Int, context: RpcCallContext) extends MapOutputTrackerMasterMessage -private[spark] case class GetMergeStatusMessage(shuffleId: Int, +private[spark] case class GetMapAndMergeStatusMessage(shuffleId: Int, context: RpcCallContext) extends MapOutputTrackerMasterMessage /** RpcEndpoint class for MapOutputTrackerMaster */ @@ -405,10 +443,10 @@ private[spark] class MapOutputTrackerMasterEndpoint( logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") tracker.post(GetMapStatusMessage(shuffleId, context)) - case GetMergeResultStatuses(shuffleId: Int) => + case GetMapAndMergeResultStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo(s"Asked to send merge result locations for shuffle $shuffleId to $hostPort") - tracker.post(GetMergeStatusMessage(shuffleId, context)) + tracker.post(GetMapAndMergeStatusMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -633,7 +671,7 @@ private[spark] class MapOutputTrackerMaster( data match { case GetMapStatusMessage(shuffleId, context) => handleStatusMessage(shuffleId, context, true) - case GetMergeStatusMessage(shuffleId, context) => + case GetMapAndMergeStatusMessage(shuffleId, context) => handleStatusMessage(shuffleId, context, false) } } catch { @@ -1145,9 +1183,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull if (fetchedMapStatuses == null) { logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + val fetchedBytes = askTracker[(Array[Byte], Array[Byte])](GetMapOutputStatuses(shuffleId)) try { - fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf) + fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._1, conf) } catch { case e: SparkException => throw new MetadataFetchFailedException(shuffleId, -1, @@ -1157,12 +1195,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Got the map output locations") mapStatuses.put(shuffleId, fetchedMapStatuses) } - var fetchedMergeStatues = mergeStatuses.get(shuffleId).orNull - if (fetchMergeResult && fetchedMergeStatues == null) { + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchMergeResult && fetchedMergeStatuses == null) { logInfo("Doing the merge fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = askTracker[Array[Byte]](GetMergeResultStatuses(shuffleId)) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId)) try { - fetchedMergeStatues = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf) + fetchedMergeStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._2, conf) } catch { case e: SparkException => throw new MetadataFetchFailedException(shuffleId, -1, @@ 
-1170,11 +1209,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
                 e.getCause)
           }
           logInfo("Got the merge output locations")
-          mergeStatuses.put(shuffleId, fetchedMergeStatues)
+          mergeStatuses.put(shuffleId, fetchedMergeStatuses)
         }
         logDebug(s"Fetching map${if (fetchMergeResult) "/merge" else ""} output statuses for shuffle $shuffleId took " +
           s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
-        (fetchedMapStatuses, fetchedMergeStatues)
+        (fetchedMapStatuses, fetchedMergeStatuses)
       }
     } else {
       (mapOutputStatuses, mergeResultStatuses)

From cd6a82c67bf5c5fa382e6715c81e0b80f351741b Mon Sep 17 00:00:00 2001
From: Venkata krishnan Sowrirajan
Date: Tue, 6 Apr 2021 12:57:34 -0700
Subject: [PATCH 17/24] getStatuses master changes

---
 .../org/apache/spark/MapOutputTracker.scala   | 44 ++++++-------------
 1 file changed, 13 insertions(+), 31 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 07d3bbaa07b6..937d19d93fe5 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1172,51 +1172,33 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(
-      shuffleId: Int, conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
-    val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
-    val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull
-    if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses == null)) {
-      logInfo(s"Don't have map/merge outputs for shuffle $shuffleId, fetching them")
+  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
       val startTimeNs = System.nanoTime()
       fetchingLock.withLock(shuffleId) {
-        var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedMapStatuses == null) {
-          logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint)
+        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
           val fetchedBytes = askTracker[(Array[Byte], Array[Byte])](GetMapOutputStatuses(shuffleId))
           try {
-            fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._1, conf)
+            fetchedStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._1, conf)
          } catch {
            case e: SparkException =>
              throw new MetadataFetchFailedException(shuffleId, -1,
                s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
                  e.getCause)
          }
-          logInfo("Got the map output locations")
-          mapStatuses.put(shuffleId, fetchedMapStatuses)
+          logInfo("Got the output locations")
+          mapStatuses.put(shuffleId, fetchedStatuses)
        }
-        var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
-        if (fetchMergeResult && fetchedMergeStatuses == null) {
-          logInfo("Doing the merge fetch; tracker endpoint = " + trackerEndpoint)
-          val fetchedBytes =
-            askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))
-          try {
-            fetchedMergeStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._2, conf)
-          } catch {
-            case e: SparkException =>
-              throw new MetadataFetchFailedException(shuffleId, -1,
-                s"Unable to deserialize broadcasted merge statuses for shuffle $shuffleId: " +
-                  e.getCause)
-          }
-          logInfo("Got the merge output locations")
-          mergeStatuses.put(shuffleId, fetchedMergeStatuses)
-        }
-        logDebug(s"Fetching map${if (fetchMergeResult) "/merge" else ""} output statuses for shuffle $shuffleId took " +
+        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
          s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
-        (fetchedMapStatuses, fetchedMergeStatuses)
+        fetchedStatuses
      }
    } else {
-      (mapOutputStatuses, mergeResultStatuses)
+      statuses
    }
  }

From 04e0e27d4457f96a4f38e68bf503b0e7f6fc57fd Mon Sep 17 00:00:00 2001
From: Venkata krishnan Sowrirajan
Date: Mon, 12 Apr 2021 11:07:04 -0700
Subject: [PATCH 18/24] Combine MapStatus and MergeStatus fetch into a single
 RPC

---
 .../org/apache/spark/MapOutputTracker.scala   | 112 ++++++++++++------
 .../apache/spark/MapOutputTrackerSuite.scala  |   6 +-
 2 files changed, 82 insertions(+), 36 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 937d19d93fe5..e07852fc7066 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,14 +20,17 @@ package org.apache.spark
 import java.io.{ByteArrayInputStream, IOException, ObjectInputStream, ObjectOutputStream}
 import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.locks.ReentrantReadWriteLock
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{HashMap, ListBuffer, Map}
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
+
 import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
 import org.roaringbitmap.RoaringBitmap
+
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -312,6 +315,7 @@ private class ShuffleStatus(
         serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
       }
     } else {
+      // If push based shuffle is enabled, serialize and cache both MapStatus and MergeStatus
       if (mapStatuses == null) {
         mapStatuses =
           serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
@@ -330,10 +334,10 @@ private class ShuffleStatus(
       isLocal: Boolean,
       minBroadcastSize: Int,
       conf: SparkConf): Array[Byte] = {
-    var mapStatuses: Array[Byte] = null
+    var mapStatusesBytes: Array[Byte] = null
     withWriteLock {
       if (cachedSerializedMapStatus == null) {
-        val serResult = MapOutputTracker.serializeOutputStatuses(
+        val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus](
           mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
         cachedSerializedMapStatus = serResult._1
         cachedSerializedBroadcast = serResult._2
       }
       // The following line has to be outside if statement since it's possible that another
       // thread initializes cachedSerializedMapStatus in-between `withReadLock` and
       // `withWriteLock`.
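The rename to mapStatusesBytes in this commit is more than cosmetic: in the previous revision the local `var mapStatuses: Array[Byte]` shadowed the `mapStatuses` field of ShuffleStatus, so the name handed to the serializer resolved to the wrong (and still-null) local rather than the field. A minimal illustration of the hazard, with invented names:

    class StatusHolder {
      val statuses: Array[String] = Array("a", "b") // the field we mean to serialize

      def serializeShadowed(): Int = {
        var statuses: Array[Byte] = null // shadows the field from here on
        // Any reference to `statuses` below sees the null local, not the field.
        if (statuses == null) -1 else statuses.length
      }

      def serializeFixed(): Int = {
        val statusBytes = statuses.mkString(",").getBytes("UTF-8") // distinct name
        statusBytes.length
      }
    }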
- mapStatuses = cachedSerializedMapStatus + mapStatusesBytes = cachedSerializedMapStatus } - mapStatuses + mapStatusesBytes } private def serializeAndCacheMergeStatuses( @@ -351,10 +355,10 @@ private class ShuffleStatus( isLocal: Boolean, minBroadcastSize: Int, conf: SparkConf): Array[Byte] = { - var mergeStatuses: Array[Byte] = null + var mergeStatusesBytes: Array[Byte] = null withWriteLock { if (cachedSerializedMergeStatus == null) { - val serResult = MapOutputTracker.serializeOutputStatuses( + val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus]( mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMergeStatus = serResult._1 cachedSerializedBroadcastMergeStatus = serResult._2 @@ -363,9 +367,9 @@ private class ShuffleStatus( // The following line has to be outside if statement since it's possible that another // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and // `withWriteLock`. - mergeStatuses = cachedSerializedMergeStatus + mergeStatusesBytes = cachedSerializedMergeStatus } - mergeStatuses + mergeStatusesBytes } // Used in testing. @@ -1167,42 +1171,84 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled + * for a given shuffle ID. NOTE: clients MUST synchronize * on this array when reading it, because on the driver, we may be changing it in place. * * (It would be nice to remove this restriction in the future.) */ - private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTimeNs = System.nanoTime() - fetchingLock.withLock(shuffleId) { - var fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = askTracker[(Array[Byte], Array[Byte])](GetMapOutputStatuses(shuffleId)) - try { - fetchedStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes._1, conf) - } catch { - case e: SparkException => - throw new MetadataFetchFailedException(shuffleId, -1, - s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + - e.getCause) + private def getStatuses( + shuffleId: Int, + conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = { + if (fetchMergeResult) { + val mapOutputStatuses = mapStatuses.get(shuffleId).orNull + val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull + + if (mapOutputStatuses == null || mergeOutputStatuses == null) { + logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchedMapStatuses == null || fetchedMergeStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId)) + try { + fetchedMapStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + fetchedMergeStatuses = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, 
conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map/merge statuses" + + s" for shuffle $shuffleId: " + e.getCause) + } + logInfo("Got the map/merge output locations") + mapStatuses.put(shuffleId, fetchedMapStatuses) + mergeStatuses.put(shuffleId, fetchedMergeStatuses) } - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedMapStatuses, fetchedMergeStatuses) } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - fetchedStatuses + } else { + (mapOutputStatuses, mergeOutputStatuses) } } else { - statuses + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapOutputStatuses(shuffleId)) + try { + fetchedStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + + e.getCause) + } + logInfo("Got the map output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedStatuses, null) + } + } else { + (statuses, null) + } } } - /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 93a6f5345f42..9d0f7f2113b6 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -573,8 +573,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) val fetchedBytes = mapWorkerTracker.trackerEndpoint - .askSync[Array[Byte]](GetMapOutputStatuses(20)) - assert(fetchedBytes(0) == 1) + .askSync[(Array[Byte], Array[Byte])](GetMapOutputStatuses(20)) + assert(fetchedBytes._1(0) == 1) // Normally `unregisterMapOutput` triggers the destroy of broadcasted value. 
// But the timing of destroying the broadcasted value is indeterminate, so we manually destroy
@@ -583,7 +583,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
       shuffleStatus.cachedSerializedBroadcast.destroy(true)
     }
     val err = intercept[SparkException] {
-      MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+      MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf)
     }
     assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses"))
   }

From 0e36f807f53a1eecae7cd377df85452af6791ce7 Mon Sep 17 00:00:00 2001
From: Venkata krishnan Sowrirajan
Date: Mon, 12 Apr 2021 12:55:05 -0700
Subject: [PATCH 19/24] Added TODO comment wrt AQE improvements

---
 core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index e07852fc7066..123d66f8c051 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -449,7 +449,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
 
     case GetMapAndMergeResultStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
-      logInfo(s"Asked to send merge result locations for shuffle $shuffleId to $hostPort")
+      logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort")
       tracker.post(GetMapAndMergeStatusMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
@@ -1408,6 +1408,9 @@
     // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
     // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
     // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
     if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
       // We have MergeStatus and full range of mapIds are requested so return a merged block.
       val numMaps = mapStatuses.length

From 336765fef69ebe9be38b3044a5449cdb67800105 Mon Sep 17 00:00:00 2001
From: Venkata krishnan Sowrirajan
Date: Mon, 12 Apr 2021 13:17:38 -0700
Subject: [PATCH 20/24] Address otterc comment

---
 .../org/apache/spark/MapOutputTracker.scala   | 30 +++++++++++++------
 .../apache/spark/MapOutputTrackerSuite.scala  |  2 +-
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 123d66f8c051..3308c1417a1e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -749,17 +749,29 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  def unregisterMergeResult(shuffleId: Int, reduceId: Int, bmAddress: BlockManagerId) {
+  /**
+   * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId
+   * is specified, it will only unregister the merge result if the mapId is part of that merge
+   * result.
+   *
+   * @param shuffleId the shuffleId.
+   * @param reduceId the reduceId.
+   * @param bmAddress block manager address.
+   * @param mapId the optional mapId which should be checked to see if it was part of the merge
+   * result.
+ */ + def unregisterMergeResult( + shuffleId: Int, + reduceId: Int, + bmAddress: BlockManagerId, + mapId: Option[Int] = None) { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => - shuffleStatus.removeMergeResult(reduceId, bmAddress) - // Here we share the same epoch for both map outputs and merge results. This means - // that even if we are only unregistering map output, this would also clear the executor - // side cached merge statuses and lead to executors re-fetching the merge statuses which - // hasn't changed, and vise versa. This is a reasonable compromise to prevent complicating - // how the epoch is currently used and to make sure the executor is always working with - // a pair of matching map statuses and merge statuses for each shuffle. - incrementEpoch() + val mergeStatus = shuffleStatus.mergeStatuses(reduceId) + if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) { + shuffleStatus.removeMergeResult(reduceId, bmAddress) + incrementEpoch() + } case None => throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID") } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 9d0f7f2113b6..95ecc3c00278 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -352,7 +352,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), bitmap, 1000L)) assert(tracker.getNumAvailableMergeResults(10) == 2) - tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)); + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)) assert(tracker.getNumAvailableMergeResults(10) == 1) tracker.stop() rpcEnv.shutdown() From 351ae5318d0e064e1fd0be1f70260bd418f6b36e Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Tue, 13 Apr 2021 09:59:32 -0700 Subject: [PATCH 21/24] MapOutputTrackerSuite test --- .../apache/spark/MapOutputTrackerSuite.scala | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 95ecc3c00278..4ef343662f0e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -588,4 +588,33 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses")) } } + + test("SPARK-32921: unregister merge result if it is present and contains the map Id") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 4, 2) + assert(tracker.containsShuffle(10)) + val bitmap1 = new RoaringBitmap() + bitmap1.add(0) + bitmap1.add(1) + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap1, 1000L)) + + val bitmap2 = new RoaringBitmap() + bitmap2.add(5) + bitmap2.add(6) + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + bitmap2, 1000L)) + assert(tracker.getNumAvailableMergeResults(10) == 2) + 
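+    // unregisterMergeResult only removes a merge result when the given mapId
+    // is in its bitmap: mapId 0 is in bitmap1, so the first call below removes
+    // the merge result for reduceId 0; mapId 1 is not in bitmap2 ({5, 6}), so
+    // the second call is a no-op; mapId 5 is in bitmap2, so the third call
+    // removes the remaining merge result.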
tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(1)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(5)) + assert(tracker.getNumAvailableMergeResults(10) == 0) + tracker.stop() + rpcEnv.shutdown() + } } From 9614a0c4124b521915923d34ea192bfde4eddcc2 Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Tue, 13 Apr 2021 11:44:41 -0700 Subject: [PATCH 22/24] Additional tests to test protocol changes --- .../org/apache/spark/MapOutputTracker.scala | 19 +++----- .../apache/spark/MapOutputTrackerSuite.scala | 47 +++++++++++++++++++ 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3308c1417a1e..03f8e0b6e1f6 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -289,12 +289,12 @@ private class ShuffleStatus( isLocal: Boolean, minBroadcastSize: Int, conf: SparkConf, - isMapOutput: Boolean): (Array[Byte], Array[Byte]) = { + isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = { var mapStatuses: Array[Byte] = null var mergeStatuses: Array[Byte] = null withReadLock { - if (isMapOutput) { + if (isMapOnlyOutput) { if (cachedSerializedMapStatus != null) { mapStatuses = cachedSerializedMapStatus } @@ -309,7 +309,7 @@ private class ShuffleStatus( } } - if (isMapOutput) { + if (isMapOnlyOutput) { if (mapStatuses == null) { mapStatuses = serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf) @@ -646,19 +646,14 @@ private[spark] class MapOutputTrackerMaster( private def handleStatusMessage( shuffleId: Int, context: RpcCallContext, - isMapOutput: Boolean): Unit = { + isMapOnlyOutput: Boolean): Unit = { val hostPort = context.senderAddress.hostPort val shuffleStatus = shuffleStatuses.get(shuffleId).head - val mapOrMerge = if (isMapOutput) { - "map" - } else { - "merge" - } - logDebug(s"Handling request to send $mapOrMerge output locations" + - s" for shuffle $shuffleId to $hostPort") + logDebug(s"Handling request to send ${if (isMapOnlyOutput) "map" else "map/merge"}" + + s" output locations for shuffle $shuffleId to $hostPort") context.reply( shuffleStatus.serializedOutputStatus(broadcastManager, isLocal, - minSizeForBroadcast, conf, isMapOutput = isMapOutput)) + minSizeForBroadcast, conf, isMapOnlyOutput = isMapOnlyOutput)) } override def run(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 4ef343662f0e..a74c0d896502 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -589,6 +589,53 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") { + val newConf = new SparkConf + newConf.set(RPC_MESSAGE_MAX_SIZE, 1) + newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast + newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize + newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + newConf.set(IS_TESTING, true) + + // needs TorrentBroadcast so 
need a SparkContext + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + val bitmap1 = new RoaringBitmap() + bitmap1.add(1) + + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) + } + masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), + bitmap1, 1000L)) + + val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf)) + val mapWorkerTracker = new MapOutputTrackerWorker(conf) + mapWorkerTracker.trackerEndpoint = + mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + val fetchedBytes = mapWorkerTracker.trackerEndpoint + .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20)) + assert(masterTracker.getNumAvailableMergeResults(20) == 1) + assert(masterTracker.getNumAvailableOutputs(20) == 100) + + val mapOutput = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf) + val mergeOutput = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf) + assert(mapOutput.length == 100) + assert(mergeOutput.length == 1) + mapWorkerTracker.stop() + masterTracker.stop() + } + } + test("SPARK-32921: unregister merge result if it is present and contains the map Id") { val rpcEnv = createRpcEnv("test") val tracker = newTrackerMaster() From 7dd24bc38b278af77418aab436a67ee1244bafe3 Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Sun, 18 Apr 2021 09:50:36 -0700 Subject: [PATCH 23/24] Address ngone51 comments --- .../org/apache/spark/MapOutputTracker.scala | 147 ++++++++---------- .../apache/spark/MapOutputTrackerSuite.scala | 6 +- 2 files changed, 67 insertions(+), 86 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 03f8e0b6e1f6..b2b79aaa8bbb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -105,18 +105,17 @@ private class ShuffleStatus( /** * The cached result of serializing the map statuses array. This cache is lazily populated when - * [[serializedOutputStatus]] is called. The cache is invalidated when map outputs are removed. + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. */ private[this] var cachedSerializedMapStatus: Array[Byte] = _ /** - * Broadcast variable holding serialized map output statuses array. When - * [[serializedOutputStatus]] serializes the map statuses array it may detect that the result is - * too large to send in a single RPC, in which case it places the serialized array into a - * broadcast variable and then sends a serialized broadcast variable instead. This variable holds - * a reference to that broadcast variable in order to keep it from being garbage collected and - * to allow for it to be explicitly destroyed later on when the ShuffleMapStage is - * garbage-collected. 
+ * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. */ private[spark] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ @@ -275,88 +274,66 @@ private class ShuffleStatus( } /** - * Serializes the mapStatuses or mergeStatuses array into an efficient compressed format. See - * the comments on `MapOutputTracker.serializeOutputStatuses()` for more details on the - * serialization format. + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format. * * This method is designed to be called multiple times and implements caching in order to speed * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to - * serialize the statuses array then serialization will only be performed in a single thread and + * serialize the map statuses then serialization will only be performed in a single thread and * all other threads will block until the cache is populated. */ - def serializedOutputStatus( + def serializedMapStatus( broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, - conf: SparkConf, - isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = { - var mapStatuses: Array[Byte] = null - var mergeStatuses: Array[Byte] = null - + conf: SparkConf): Array[Byte] = { + var result: Array[Byte] = null withReadLock { - if (isMapOnlyOutput) { - if (cachedSerializedMapStatus != null) { - mapStatuses = cachedSerializedMapStatus - } - } else { - if (cachedSerializedMapStatus != null) { - mapStatuses = cachedSerializedMapStatus - } - - if (cachedSerializedMergeStatus != null) { - mergeStatuses = cachedSerializedMergeStatus - } + if (cachedSerializedMapStatus != null) { + result = cachedSerializedMapStatus } } - if (isMapOnlyOutput) { - if (mapStatuses == null) { - mapStatuses = - serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf) - } - } else { - // If push based shuffle enabled, serialize and cache both Map and Merge Status - if (mapStatuses == null) { - mapStatuses = - serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf) - } - - if (mergeStatuses == null) { - mergeStatuses = - serializeAndCacheMergeStatuses(broadcastManager, isLocal, minBroadcastSize, conf) - } - } - (mapStatuses, mergeStatuses) - } - - private def serializeAndCacheMapStatuses( - broadcastManager: BroadcastManager, - isLocal: Boolean, - minBroadcastSize: Int, - conf: SparkConf): Array[Byte] = { - var mapStatusesBytes: Array[Byte] = null - withWriteLock { + if (result == null) withWriteLock { if (cachedSerializedMapStatus == null) { val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus]( mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMapStatus = serResult._1 cachedSerializedBroadcast = serResult._2 } - // The following line has to be outside if statement since it's possible that another - // thread initializes cachedSerializedMapStatus in-between `withReadLock` 
and - // `withWriteLock`. - mapStatusesBytes = cachedSerializedMapStatus + // The following line has to be outside if statement since it's possible that another thread + // initializes cachedSerializedMapStatus in-between `withReadLock` and `withWriteLock`. + result = cachedSerializedMapStatus } - mapStatusesBytes + result } - private def serializeAndCacheMergeStatuses( + /** + * Serializes the mapStatuses and mergeStatuses array into an efficient compressed format. + * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details + * on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the statuses array then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. + */ + def serializedMapAndMergeStatus( broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, - conf: SparkConf): Array[Byte] = { + conf: SparkConf): (Array[Byte], Array[Byte]) = { + val mapStatusesBytes: Array[Byte] = + serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf) var mergeStatusesBytes: Array[Byte] = null - withWriteLock { + + withReadLock { + if (cachedSerializedMergeStatus != null) { + mergeStatusesBytes = cachedSerializedMergeStatus + } + } + + if (mergeStatusesBytes == null) withWriteLock { if (cachedSerializedMergeStatus == null) { val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus]( mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) @@ -369,7 +346,7 @@ private class ShuffleStatus( // `withWriteLock`. mergeStatusesBytes = cachedSerializedMergeStatus } - mergeStatusesBytes + (mapStatusesBytes, mergeStatusesBytes) } // Used in testing. 
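To make the locking discipline shared by serializedMapStatus and serializedMapAndMergeStatus easier to follow, here is a minimal, self-contained sketch of the same optimistic-read/double-checked-write idiom in isolation. The CachedSerializer class and its compute parameter are illustrative stand-ins, not part of the patch series:

import java.util.concurrent.locks.ReentrantReadWriteLock

// Minimal sketch: try the cache under a read lock first; fall back to a write
// lock, re-check, compute at most once, and read the field again before
// returning in case another thread filled it in between the two locks.
class CachedSerializer(compute: () => Array[Byte]) {
  private val lock = new ReentrantReadWriteLock()
  private var cached: Array[Byte] = _

  def get(): Array[Byte] = {
    var result: Array[Byte] = null
    lock.readLock().lock()
    try {
      if (cached != null) {
        result = cached
      }
    } finally {
      lock.readLock().unlock()
    }
    if (result == null) {
      lock.writeLock().lock()
      try {
        if (cached == null) {
          cached = compute()
        }
        // Read outside the if, for the reason the patch's comment gives:
        // another thread may have populated the cache between the read lock
        // being released and the write lock being acquired.
        result = cached
      } finally {
        lock.writeLock().unlock()
      }
    }
    result
  }
}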
@@ -429,9 +406,9 @@ private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int)
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
 private[spark] sealed trait MapOutputTrackerMasterMessage
-private[spark] case class GetMapStatusMessage(shuffleId: Int,
+private[spark] case class GetMapOutputMessage(shuffleId: Int,
   context: RpcCallContext) extends MapOutputTrackerMasterMessage
-private[spark] case class GetMapAndMergeStatusMessage(shuffleId: Int,
+private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int,
   context: RpcCallContext) extends MapOutputTrackerMasterMessage
 
 /** RpcEndpoint class for MapOutputTrackerMaster */
@@ -445,12 +422,12 @@ private[spark] class MapOutputTrackerMasterEndpoint(
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
       logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort")
-      tracker.post(GetMapStatusMessage(shuffleId, context))
+      tracker.post(GetMapOutputMessage(shuffleId, context))
 
     case GetMapAndMergeResultStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
       logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort")
-      tracker.post(GetMapAndMergeStatusMessage(shuffleId, context))
+      tracker.post(GetMapAndMergeOutputMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -646,14 +623,19 @@ private[spark] class MapOutputTrackerMaster(
     private def handleStatusMessage(
         shuffleId: Int,
         context: RpcCallContext,
-        isMapOnlyOutput: Boolean): Unit = {
+        needMergeOutput: Boolean): Unit = {
       val hostPort = context.senderAddress.hostPort
       val shuffleStatus = shuffleStatuses.get(shuffleId).head
-      logDebug(s"Handling request to send ${if (isMapOnlyOutput) "map" else "map/merge"}" +
+      logDebug(s"Handling request to send ${if (needMergeOutput) "map/merge" else "map"}" +
         s" output locations for shuffle $shuffleId to $hostPort")
-      context.reply(
-        shuffleStatus.serializedOutputStatus(broadcastManager, isLocal,
-          minSizeForBroadcast, conf, isMapOnlyOutput = isMapOnlyOutput))
+      if (needMergeOutput) {
+        context.reply(
+          shuffleStatus.serializedMapAndMergeStatus(
+            broadcastManager, isLocal, minSizeForBroadcast, conf))
+      } else {
+        context.reply(
+          shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
+      }
     }
 
     override def run(): Unit = {
@@ -668,10 +650,10 @@ private[spark] class MapOutputTrackerMaster(
         }
 
         data match {
-          case GetMapStatusMessage(shuffleId, context) =>
-            handleStatusMessage(shuffleId, context, true)
-          case GetMapAndMergeStatusMessage(shuffleId, context) =>
+          case GetMapOutputMessage(shuffleId, context) =>
             handleStatusMessage(shuffleId, context, false)
+          case GetMapAndMergeOutputMessage(shuffleId, context) =>
+            handleStatusMessage(shuffleId, context, true)
         }
       } catch {
         case NonFatal(e) => logError(e.getMessage, e)
@@ -684,7 +666,7 @@ private[spark] class MapOutputTrackerMaster(
     }
 
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
-  private val PoisonPill = GetMapStatusMessage(-99, null)
+  private val PoisonPill = GetMapOutputMessage(-99, null)
 
   // Used only in unit tests.
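The net protocol after this refactor fits in a compact sketch: a plain map-status fetch replies with a single serialized byte array, while a push-based fetch replies with a (map, merge) pair retrieved in one round trip, which keeps the executor's map and merge views of a shuffle consistent. All names below are illustrative stand-ins, and serializedMap/serializedMerge are stubs for the ShuffleStatus methods above:

// Sketch of the two driver-side messages and their reply shapes.
sealed trait TrackerMessageSketch
final case class GetMapOutputSketch(shuffleId: Int) extends TrackerMessageSketch
final case class GetMapAndMergeOutputSketch(shuffleId: Int) extends TrackerMessageSketch

object ProtocolSketch {
  // Stubs standing in for serializedMapStatus / serializedMapAndMergeStatus.
  private def serializedMap(shuffleId: Int): Array[Byte] = Array[Byte](1)
  private def serializedMerge(shuffleId: Int): Array[Byte] = Array[Byte](2)

  def handle(msg: TrackerMessageSketch): Any = msg match {
    case GetMapOutputSketch(id) =>
      serializedMap(id)                        // replies with Array[Byte]
    case GetMapAndMergeOutputSketch(id) =>
      (serializedMap(id), serializedMerge(id)) // replies with (Array[Byte], Array[Byte])
  }
}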
private[spark] def getNumCachedSerializedBroadcast: Int = { @@ -915,7 +897,7 @@ private[spark] class MapOutputTrackerMaster( } else { Nil } - if (!preferredLoc.isEmpty) { + if (preferredLoc.nonEmpty) { preferredLoc } else { if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && @@ -1232,11 +1214,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr var fetchedStatuses = mapStatuses.get(shuffleId).orNull if (fetchedStatuses == null) { logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = - askTracker[(Array[Byte], Array[Byte])](GetMapOutputStatuses(shuffleId)) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) try { fetchedStatuses = - MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) } catch { case e: SparkException => throw new MetadataFetchFailedException(shuffleId, -1, diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index a74c0d896502..f4b47e2bb0cd 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -573,8 +573,8 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) val fetchedBytes = mapWorkerTracker.trackerEndpoint - .askSync[(Array[Byte], Array[Byte])](GetMapOutputStatuses(20)) - assert(fetchedBytes._1(0) == 1) + .askSync[Array[Byte]](GetMapOutputStatuses(20)) + assert(fetchedBytes(0) == 1) // Normally `unregisterMapOutput` triggers the destroy of broadcasted value. // But the timing of destroying broadcasted value is indeterminate, we manually destroy @@ -583,7 +583,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { shuffleStatus.cachedSerializedBroadcast.destroy(true) } val err = intercept[SparkException] { - MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) } assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses")) } From d1422bdd5a7b17c69a43208c23becf76a4bad16c Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Mon, 19 Apr 2021 16:30:37 -0700 Subject: [PATCH 24/24] Address other comments --- .../org/apache/spark/MapOutputTracker.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index b2b79aaa8bbb..b749d7e8626b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -52,8 +52,7 @@ import org.apache.spark.util._ */ private class ShuffleStatus( numPartitions: Int, - numReducers: Int, - isPushBasedShuffleEnabled: Boolean = false) extends Logging { + numReducers: Int = -1) extends Logging { private val (readLock, writeLock) = { val lock = new ReentrantReadWriteLock() @@ -97,7 +96,7 @@ private class ShuffleStatus( * provides a reducer oriented view of the shuffle status specifically for the results of * merging shuffle partition blocks into per-partition merged shuffle files. 
*/ - val mergeStatuses = if (isPushBasedShuffleEnabled) { + val mergeStatuses = if (numReducers > 0) { new Array[MergeStatus](numReducers) } else { Array.empty[MergeStatus] @@ -674,9 +673,14 @@ private[spark] class MapOutputTrackerMaster( } def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { - if (shuffleStatuses.put(shuffleId, - new ShuffleStatus(numMaps, numReduces, pushBasedShuffleEnabled)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + if (pushBasedShuffleEnabled) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } else { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } } } @@ -1399,7 +1403,8 @@ private[spark] object MapOutputTracker extends Logging { // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle, // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end // TODO: map indexes - if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) { + if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0 + && endMapIndex == mapStatuses.length) { // We have MergeStatus and full range of mapIds are requested so return a merged block. val numMaps = mapStatuses.length mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { @@ -1413,7 +1418,8 @@ private[spark] object MapOutputTracker extends Logging { ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper // shuffle partition blocks, fetch the original map produced shuffle partition blocks - mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex) + val mapStatusesWithIndex = mapStatuses.zipWithIndex + mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex) } else { // If MergeStatus is not available for the given partition, fall back to // fetching all the original mapper shuffle partition blocks
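
The "holes" handling above is easiest to see with a small worked example. In the sketch below, which is illustrative only (missingMaps stands in for MergeStatus.getMissingMaps), the bitmap records which map outputs were folded into the merged block, and every absent index is a hole that must still be fetched as an original, unmerged shuffle block:

import org.roaringbitmap.RoaringBitmap

object MissingMapsSketch {
  // A merged shuffle partition tracks which map outputs it contains in a
  // bitmap; every index absent from the bitmap is a "hole" to be fetched
  // directly from the corresponding mapper.
  def missingMaps(tracker: RoaringBitmap, numMaps: Int): Seq[Int] =
    (0 until numMaps).filterNot(i => tracker.contains(i))

  def main(args: Array[String]): Unit = {
    val merged = new RoaringBitmap()
    merged.add(0L, 3L) // map outputs 0, 1 and 2 made it into the merged block
    // Prints Vector(3, 4): these map outputs are fetched as original blocks.
    println(missingMaps(merged, numMaps = 5))
  }
}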