diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ce71c2c7bc306..b749d7e8626ba 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -29,13 +29,14 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} +import org.roaringbitmap.RoaringBitmap import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -49,7 +50,9 @@ import org.apache.spark.util._ * * All public methods of this class are thread-safe. */ -private class ShuffleStatus(numPartitions: Int) extends Logging { +private class ShuffleStatus( + numPartitions: Int, + numReducers: Int = -1) extends Logging { private val (readLock, writeLock) = { val lock = new ReentrantReadWriteLock() @@ -86,6 +89,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { // Exposed for testing val mapStatuses = new Array[MapStatus](numPartitions) + /** + * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the + * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for + * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array + * provides a reducer-oriented view of the shuffle status specifically for the results of + * merging shuffle partition blocks into per-partition merged shuffle files. + */ + val mergeStatuses = if (numReducers > 0) { + new Array[MergeStatus](numReducers) + } else { + Array.empty[MergeStatus] + } + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. @@ -102,12 +118,24 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ private[spark] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + /** + * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus. + */ + private[this] var cachedSerializedMergeStatus: Array[Byte] = _ + + private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _ + /** * Counter tracking the number of partitions that have output. This is a performance optimization * to avoid having to count the number of non-null entries in the `mapStatuses` array and should * be equivalent to `mapStatuses.count(_ ne null)`. */ - private[this] var _numAvailableOutputs: Int = 0 + private[this] var _numAvailableMapOutputs: Int = 0 + + /** + * Counter tracking the number of MergeStatus results received so far from the shuffle services. + */ + private[this] var _numAvailableMergeResults: Int = 0 /** * Register a map output.
If there is already a registered location for the map output then it @@ -115,7 +143,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { if (mapStatuses(mapIndex) == null) { - _numAvailableOutputs += 1 + _numAvailableMapOutputs += 1 invalidateSerializedMapOutputStatusCache() } mapStatuses(mapIndex) = status @@ -149,12 +177,36 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { - _numAvailableOutputs -= 1 + _numAvailableMapOutputs -= 1 mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } + /** + * Register a merge result. + */ + def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock { + if (mergeStatuses(reduceId) != status) { + _numAvailableMergeResults += 1 + invalidateSerializedMergeOutputStatusCache() + } + mergeStatuses(reduceId) = status + } + + // TODO support updateMergeResult for similar use cases as updateMapOutput + + /** + * Remove the merge result which was served by the specified block manager. + */ + def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock { + if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + /** * Removes all shuffle outputs associated with this host. Note that this will also remove * outputs which are served by an external shuffle server (if one exists). @@ -181,18 +233,33 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { for (mapIndex <- mapStatuses.indices) { if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { - _numAvailableOutputs -= 1 + _numAvailableMapOutputs -= 1 mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } + for (reduceId <- mergeStatuses.indices) { + if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle map outputs. + */ + def numAvailableMapOutputs: Int = withReadLock { + _numAvailableMapOutputs } /** - * Number of partitions that have shuffle outputs. + * Number of shuffle partitions that have already been merge finalized when push-based + * shuffle is enabled. */ - def numAvailableOutputs: Int = withReadLock { - _numAvailableOutputs + def numAvailableMergeResults: Int = withReadLock { + _numAvailableMergeResults } /** @@ -200,19 +267,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ def findMissingPartitions(): Seq[Int] = withReadLock { val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + assert(missing.size == numPartitions - _numAvailableMapOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}") missing } /** * Serializes the mapStatuses array into an efficient compressed format.
See the comments on - * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format. * * This method is designed to be called multiple times and implements caching in order to speed * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to - * serialize the map statuses then serialization will only be performed in a single thread and all - * other threads will block until the cache is populated. + * serialize the map statuses then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. */ def serializedMapStatus( broadcastManager: BroadcastManager, @@ -220,7 +287,6 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { minBroadcastSize: Int, conf: SparkConf): Array[Byte] = { var result: Array[Byte] = null - withReadLock { if (cachedSerializedMapStatus != null) { result = cachedSerializedMapStatus @@ -229,7 +295,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { if (result == null) withWriteLock { if (cachedSerializedMapStatus == null) { - val serResult = MapOutputTracker.serializeMapStatuses( + val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus]( mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMapStatus = serResult._1 cachedSerializedBroadcast = serResult._2 @@ -241,6 +307,47 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { result } + /** + * Serializes the mapStatuses and mergeStatuses arrays into an efficient compressed format. + * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details + * on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the statuses array then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. + */ + def serializedMapAndMergeStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): (Array[Byte], Array[Byte]) = { + val mapStatusesBytes: Array[Byte] = + serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf) + var mergeStatusesBytes: Array[Byte] = null + + withReadLock { + if (cachedSerializedMergeStatus != null) { + mergeStatusesBytes = cachedSerializedMergeStatus + } + } + + if (mergeStatusesBytes == null) withWriteLock { + if (cachedSerializedMergeStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus]( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + + // The following line has to be outside the if statement since it's possible that another + // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and + // `withWriteLock`. + mergeStatusesBytes = cachedSerializedMergeStatus + } + (mapStatusesBytes, mergeStatusesBytes) + } + // Used in testing.
def hasCachedSerializedBroadcast: Boolean = withReadLock { cachedSerializedBroadcast != null @@ -254,6 +361,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { f(mapStatuses) } + def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock { + f(mergeStatuses) + } + /** * Clears the cached serialized map output statuses. */ @@ -269,14 +380,35 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { } cachedSerializedMapStatus = null } + + /** + * Clears the cached serialized merge result statuses. + */ + def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock { + if (cachedSerializedBroadcastMergeStatus != null) { + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. + cachedSerializedBroadcastMergeStatus.destroy() + } + cachedSerializedBroadcastMergeStatus = null + } + cachedSerializedMergeStatus = null + } } private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage +private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int) + extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) +private[spark] sealed trait MapOutputTrackerMasterMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( @@ -288,8 +420,13 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort - logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}") - tracker.post(new GetMapOutputMessage(shuffleId, context)) + logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapOutputMessage(shuffleId, context)) + + case GetMapAndMergeResultStatuses(shuffleId: Int) => + val hostPort = context.senderAddress.hostPort + logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapAndMergeOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -367,6 +504,40 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors upon fetch failure on an entire merged shuffle reduce partition. + * Such failures can happen if the shuffle client fails to fetch the metadata for the given + * merged shuffle partition. This method gets the server URIs and output sizes for each + * shuffle block that is merged in the specified merged shuffle block, so that a fetch failure + * on a merged shuffle block can fall back to fetching the unmerged blocks.
+ * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + + /** + * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is + * to get the server URIs and output sizes for each shuffle block that is merged in the specified + * merged shuffle partition chunk, so that a fetch failure on a merged shuffle block chunk can + * fall back to fetching the unmerged blocks. + * + * chunkBitmap tracks the mapIds which are part of the current merged chunk; this way, if there + * is a fetch failure on the merged chunk, it can fall back to fetching the corresponding + * original blocks that are part of this merged chunk. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -415,8 +586,11 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // requests for map output statuses - private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + // requests for MapOutputTrackerMasterMessages + private val mapOutputTrackerMasterMessages = + new LinkedBlockingQueue[MapOutputTrackerMasterMessage] + + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -439,31 +613,47 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetMapOutputMessage): Unit = { - mapOutputRequests.offer(message) + def post(message: MapOutputTrackerMasterMessage): Unit = { + mapOutputTrackerMasterMessages.offer(message) } /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { + private def handleStatusMessage( + shuffleId: Int, + context: RpcCallContext, + needMergeOutput: Boolean): Unit = { + val hostPort = context.senderAddress.hostPort + val shuffleStatus = shuffleStatuses.get(shuffleId).head + logDebug(s"Handling request to send ${if (needMergeOutput) "map/merge" else "map"}" + + s" output locations for shuffle $shuffleId to $hostPort") + if (needMergeOutput) { + context.reply( + shuffleStatus. + serializedMapAndMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } else { + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } + } + override def run(): Unit = { try { while (true) { try { - val data = mapOutputRequests.take() - if (data == PoisonPill) { + val data = mapOutputTrackerMasterMessages.take() + if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it.
- mapOutputRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) return } - val context = data.context - val shuffleId = data.shuffleId - val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) - val shuffleStatus = shuffleStatuses.get(shuffleId).head - context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, - conf)) + + data match { + case GetMapOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, false) + case GetMapAndMergeOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, true) + } } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -475,16 +665,22 @@ private[spark] class MapOutputTrackerMaster( } /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new GetMapOutputMessage(-99, null) + private val PoisonPill = GetMapOutputMessage(-99, null) // Used only in unit tests. private[spark] def getNumCachedSerializedBroadcast: Int = { shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) } - def registerShuffle(shuffleId: Int, numMaps: Int): Unit = { - if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { + if (pushBasedShuffleEnabled) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } else { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } } } @@ -524,10 +720,49 @@ private[spark] class MapOutputTrackerMaster( } } + def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = { + shuffleStatuses(shuffleId).addMergeResult(reduceId, status) + } + + def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = { + statuses.foreach { + case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status) + } + } + + /** + * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId + * is specified, it will only unregister the merge result if the mapId is part of that merge + * result. + * + * @param shuffleId the shuffleId. + * @param reduceId the reduceId. + * @param bmAddress block manager address. + * @param mapId the optional mapId which should be checked to see if it was part of the merge + * result.
+ */ + def unregisterMergeResult( + shuffleId: Int, + reduceId: Int, + bmAddress: BlockManagerId, + mapId: Option[Int] = None) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + val mergeStatus = shuffleStatus.mergeStatuses(reduceId) + if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) { + shuffleStatus.removeMergeResult(reduceId, bmAddress) + incrementEpoch() + } + case None => + throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int): Unit = { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } @@ -554,7 +789,12 @@ private[spark] class MapOutputTrackerMaster( def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) def getNumAvailableOutputs(shuffleId: Int): Int = { - shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0) + } + + /** VisibleForTest. Invoked in test only. */ + private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0) } /** @@ -633,7 +873,9 @@ private[spark] class MapOutputTrackerMaster( /** * Return the preferred hosts on which to run the given map output partition in a given shuffle, - * i.e. the nodes that the most outputs for that partition are on. + * i.e. the nodes that the most outputs for that partition are on. If the map output is + * pre-merged, then return the node where the merged block is located if the merge ratio is + * above the threshold. * * @param dep shuffle dependency object * @param partitionId map output partition that we want to read @@ -641,15 +883,40 @@ private[spark] class MapOutputTrackerMaster( */ def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int) : Seq[String] = { - if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && - dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { - val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, - dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) - if (blockManagerIds.nonEmpty) { - blockManagerIds.get.map(_.host) + val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull + if (shuffleStatus != null) { + // Check if the map output is pre-merged and if the merge ratio is above the threshold. + // If so, the location of the merged block is the preferred location. 
+ val preferredLoc = if (pushBasedShuffleEnabled) { + shuffleStatus.withMergeStatuses { statuses => + val status = statuses(partitionId) + val numMaps = dep.rdd.partitions.length + if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps + <= (1 - REDUCER_PREF_LOCS_FRACTION)) { + Seq(status.location.host) + } else { + Nil + } + } + } else { + Nil + } + if (preferredLoc.nonEmpty) { + preferredLoc + } else { + if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && + dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { + val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, + dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) + if (blockManagerIds.nonEmpty) { + blockManagerIds.get.map(_.host) + } else { + Nil + } + } else { + Nil + } + } } else { Nil } @@ -774,8 +1041,25 @@ private[spark] class MapOutputTrackerMaster( } } + // This method is only called in local mode. Since push-based shuffle won't be + // enabled in local mode, this method returns an empty iterator. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator + } + + // This method is only called in local mode. Since push-based shuffle won't be + // enabled in local mode, this method returns an empty iterator. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator + } + override def stop(): Unit = { - mapOutputRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) threadpool.shutdown() try { sendTracker(StopMapOutputTracker) @@ -799,6 +1083,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + val mergeStatuses: Map[Int, Array[MergeStatus]] = + new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala + + private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf) + /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching * the same shuffle block.
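The merge-ratio check in `getPreferredLocationsForShuffle` above reduces to simple threshold arithmetic over the merge bitmap. Below is a minimal, self-contained sketch of that arithmetic; the `ReducerPrefLocsFraction` constant mirrors `REDUCER_PREF_LOCS_FRACTION` (0.2), and the object and helper names are illustrative stand-ins, not APIs from this patch:

```scala
import org.roaringbitmap.RoaringBitmap

object MergeRatioSketch {
  // Assumed stand-in for REDUCER_PREF_LOCS_FRACTION in MapOutputTrackerMaster.
  val ReducerPrefLocsFraction = 0.2

  // True when the fraction of unmerged map outputs is small enough that the
  // merged block's host is worth preferring for the reduce task.
  def preferMergedLocation(tracker: RoaringBitmap, numMaps: Int): Boolean = {
    val numMissing = (0 until numMaps).count(i => !tracker.contains(i))
    numMissing.toDouble / numMaps <= 1 - ReducerPrefLocsFraction
  }

  def main(args: Array[String]): Unit = {
    val merged = new RoaringBitmap()
    (0 until 5).foreach(i => merged.add(i))  // 5 of 6 map outputs merged
    assert(preferMergedLocation(merged, 6))  // 1/6 missing <= 0.8 => prefer merged host
    val sparse = new RoaringBitmap()
    sparse.add(0)                            // only 1 of 6 map outputs merged
    assert(!preferMergedLocation(sparse, 6)) // 5/6 missing > 0.8 => fall back
  }
}
```

The same 5-of-6 versus 1-of-6 split is exercised by the `getPreferredLocationsForShuffle with MergeStatus` test later in this patch.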
@@ -812,61 +1101,150 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId") - val statuses = getStatuses(shuffleId, conf) + val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf) try { - val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex + val actualEndMapIndex = + if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex logDebug(s"Convert map statuses for shuffle $shuffleId, " + s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition") MapOutputTracker.convertMapStatuses( - shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex) + shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex, + actualEndMapIndex, Option(mergedOutputStatuses)) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf) + try { + val mergeStatus = mergeResultStatuses(partitionId) + // If the original MergeStatus is no longer available, we cannot identify the list of + // unmerged blocks to fetch, so throw a MetadataFetchFailedException. + MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId) + // Use the MergeStatus's partition-level bitmap since we are doing partition-level fallback + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, + mapOutputStatuses, mergeStatus.tracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, _) = getStatuses(shuffleId, conf) + try { + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses, + chunkTracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() throw e } } /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * Get or fetch the array of MapStatuses and MergeStatuses if push-based shuffle is enabled + * for a given shuffle ID. NOTE: clients MUST synchronize * on this array when reading it, because on the driver, we may be changing it in place. * * (It would be nice to remove this restriction in the future.)
*/ - private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTimeNs = System.nanoTime() - fetchingLock.withLock(shuffleId) { - var fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - try { - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf) - } catch { - case e: SparkException => - throw new MetadataFetchFailedException(shuffleId, -1, - s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + - e.getCause) + private def getStatuses( + shuffleId: Int, + conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = { + if (fetchMergeResult) { + val mapOutputStatuses = mapStatuses.get(shuffleId).orNull + val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull + + if (mapOutputStatuses == null || mergeOutputStatuses == null) { + logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchedMapStatuses == null || fetchedMergeStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId)) + try { + fetchedMapStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + fetchedMergeStatuses = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map/merge statuses" + + s" for shuffle $shuffleId: " + e.getCause) + } + logInfo("Got the map/merge output locations") + mapStatuses.put(shuffleId, fetchedMapStatuses) + mergeStatuses.put(shuffleId, fetchedMergeStatuses) } - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedMapStatuses, fetchedMergeStatuses) } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - fetchedStatuses + } else { + (mapOutputStatuses, mergeOutputStatuses) } } else { - statuses + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + try { + fetchedStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + + e.getCause) + } + 
logInfo("Got the map output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedStatuses, null) + } + } else { + (statuses, null) + } } } - /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) + mergeStatuses.remove(shuffleId) } /** @@ -880,6 +1258,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Updating epoch to " + newEpoch + " and clearing cache") epoch = newEpoch mapStatuses.clear() + mergeStatuses.clear() } } } @@ -891,11 +1270,13 @@ private[spark] object MapOutputTracker extends Logging { private val DIRECT = 0 private val BROADCAST = 1 - // Serialize an array of map output locations into an efficient byte format so that we can send - // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will - // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses( - statuses: Array[MapStatus], + private val SHUFFLE_PUSH_MAP_ID = -1 + + // Serialize an array of map/merge output locations into an efficient byte format so that we can + // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will + // generally be pretty compressible because many outputs will be on the same hostname. + def serializeOutputStatuses[T <: ShuffleOutputStatus]( + statuses: Array[T], broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, @@ -931,15 +1312,16 @@ private[spark] object MapOutputTracker extends Logging { oos.close() } val outArr = out.toByteArray - logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arr.length) (outArr, bcast) } else { (arr, null) } } - // Opposite of serializeMapStatuses. - def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = { + // Opposite of serializeOutputStatuses. + def deserializeOutputStatuses[T <: ShuffleOutputStatus]( + bytes: Array[Byte], conf: SparkConf): Array[T] = { assert (bytes.length > 0) def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { @@ -958,20 +1340,22 @@ private[spark] object MapOutputTracker extends Logging { bytes(0) match { case DIRECT => - deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[T]] case BROADCAST => try { // deserialize the Broadcast, pull .value array out of it, and then deserialize that val bcast = deserializeObject(bytes, 1, bytes.length - 1). asInstanceOf[Broadcast[Array[Byte]]] - logInfo("Broadcast mapstatuses size = " + bytes.length + + logInfo("Broadcast outputstatuses size = " + bytes.length + ", actual size = " + bcast.value.length) // Important - ignore the DIRECT tag ! 
Start from offset 1 - deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[T]] } catch { case e: IOException => - logWarning("Exception encountered during deserializing broadcasted map statuses: ", e) - throw new SparkException("Unable to deserialize broadcasted map statuses", e) + logWarning("Exception encountered during deserializing broadcasted" + + " output statuses: ", e) + throw new SparkException("Unable to deserialize broadcasted" + + " output statuses", e) } case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } @@ -983,15 +1367,19 @@ private[spark] object MapOutputTracker extends Logging { * stored at that block manager. * Note that empty blocks are filtered in the result. * + * If push-based shuffle is enabled and an array of merge statuses is available, prioritize + * the locations of the merged shuffle partitions over unmerged shuffle blocks. + * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. * * @param shuffleId Identifier for the shuffle * @param startPartition Start of map output partition ID range (included in range) * @param endPartition End of map output partition ID range (excluded from range) - * @param statuses List of map statuses, indexed by map partition index. + * @param mapStatuses List of map statuses, indexed by map partition index. * @param startMapIndex Start Map index. * @param endMapIndex End Map index. + * @param mergeStatuses List of merge statuses, indexed by reduce ID. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block id, shuffle block size, map index) * tuples describing the shuffle blocks that are stored at that block manager. @@ -1000,18 +1388,57 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus], + mapStatuses: Array[MapStatus], startMapIndex : Int, - endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - assert (statuses != null) + endMapIndex: Int, + mergeStatuses: Option[Array[MergeStatus]] = None): + Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] - val iter = statuses.iterator.zipWithIndex - for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { - if (status == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) - } else { + // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle + // partition consists of blocks merged in random order, we are unable to serve map index + // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of + // map outputs, it usually indicates skewed partitions which push-based shuffle delegates + // to AQE to handle.
+ // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle, + // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end + // TODO: map indexes + if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0 + && endMapIndex == mapStatuses.length) { + // We have MergeStatus and the full range of mapIds is requested, so return merged blocks. + val numMaps = mapStatuses.length + mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { + case (mergeStatus, partId) => + val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { + // If MergeStatus is available for the given partition, add location of the + // pre-merged shuffle partition for this partition ID. Here we create a + // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is + // a merged shuffle block. + splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) + // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper + // shuffle partition blocks, fetch the original map-produced shuffle partition blocks + val mapStatusesWithIndex = mapStatuses.zipWithIndex + mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex) + } else { + // If MergeStatus is not available for the given partition, fall back to + // fetching all the original mapper shuffle partition blocks + mapStatuses.zipWithIndex.toSeq + } + // Add location for the mapper shuffle partition blocks + for ((mapStatus, mapIndex) <- remainingMapStatuses) { + validateStatus(mapStatus, shuffleId, partId) + val size = mapStatus.getSizeForBlock(partId) + if (size != 0) { + splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex)) + } + } + } + } else { + val iter = mapStatuses.iterator.zipWithIndex + for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { + validateStatus(status, shuffleId, startPartition) for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { @@ -1024,4 +1451,47 @@ private[spark] object MapOutputTracker extends Logging { } splitsByAddress.mapValues(_.toSeq).iterator } + + /** + * Given a shuffle ID, a partition ID, an array of map statuses, and a bitmap corresponding + * to either a merged shuffle partition or a merged shuffle partition chunk, identify + * the metadata about the shuffle partition blocks that are merged into the merged shuffle + * partition or partition chunk represented by the bitmap. + * + * @param shuffleId Identifier for the shuffle + * @param partitionId The partition ID of the MergeStatus for which we look for the metadata + * of the merged shuffle partition blocks + * @param mapStatuses List of map statuses, indexed by map ID + * @param tracker bitmap containing mapIndexes that belong to the merged block or merged + * block chunk. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager.
+ */ + def getMapStatusesForMergeStatus( + shuffleId: Int, + partitionId: Int, + mapStatuses: Array[MapStatus], + tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null && tracker != null) + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + for ((status, mapIndex) <- mapStatuses.zipWithIndex) { + // Only add blocks that are merged + if (tracker.contains(mapIndex)) { + MapOutputTracker.validateStatus(status, shuffleId, partitionId) + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapId, partitionId), + status.getSizeForBlock(partitionId), mapIndex)) + } + } + splitsByAddress.mapValues(_.toSeq).iterator + } + + def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2e7c4dae038e..a92d9fab6efc6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -456,7 +456,8 @@ private[spark] class DAGScheduler( // since we can't do it in the RDD constructor because # of partitions is unknown logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " + s"shuffle ${shuffleDep.shuffleId}") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length, + shuffleDep.partitioner.numPartitions) } stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1239c32cee3ab..07eed76805dd2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -28,12 +28,18 @@ import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +/** + * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing + * code to handle MergeStatus inside MapOutputTracker. + */ +private[spark] trait ShuffleOutputStatus + /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus { +private[spark] sealed trait MapStatus extends ShuffleOutputStatus { /** Location where this task output is. */ def location: BlockManagerId diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala new file mode 100644 index 0000000000000..77d8f8e040da1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.network.shuffle.protocol.MergeStatuses +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +/** + * The status for the result of merging shuffle partition blocks per individual shuffle partition + * maintained by the scheduler. The scheduler would separate the + * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from + * ExternalShuffleService into individual [[MergeStatus]] entries, which are maintained inside + * MapOutputTracker to be served to the reducers when they start fetching shuffle partition + * blocks. Note that the reducers are ultimately fetching individual chunks inside a merged + * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]]. + * Between the scheduler-maintained MergeStatus and the per-shuffle-partition meta file + * maintained by the shuffle service, we are effectively dividing the metadata for a push-based + * shuffle into two layers. The scheduler would track the top-level metadata at the shuffle + * partition level with MergeStatus, and the shuffle service would maintain the partition-level + * metadata about how to further divide a merged shuffle partition into multiple chunks with the + * per-partition meta file. This helps to reduce the amount of data the scheduler needs to + * maintain for push-based shuffle. + */ +private[spark] class MergeStatus( + private[this] var loc: BlockManagerId, + private[this] var mapTracker: RoaringBitmap, + private[this] var size: Long) + extends Externalizable with ShuffleOutputStatus { + + protected def this() = this(null, null, -1) // For deserialization only + + def location: BlockManagerId = loc + + def totalSize: Long = size + + def tracker: RoaringBitmap = mapTracker + + /** + * Get the list of mapper IDs for missing mapper partition blocks that are not merged. + * The reducer will use this information to decide which shuffle partition blocks to + * fetch in the original way. + */ + def getMissingMaps(numMaps: Int): Seq[Int] = { + (0 until numMaps).filter(i => !mapTracker.contains(i)) + } + + /** + * Get the number of missing mapper partition blocks that are not merged.
+ */ + def getNumMissingMapOutputs(numMaps: Int): Int = { + (0 until numMaps).count(i => !mapTracker.contains(i)) + } + + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + mapTracker.writeExternal(out) + out.writeLong(size) + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + mapTracker = new RoaringBitmap() + mapTracker.readExternal(in) + size = in.readLong() + } +} + +private[spark] object MergeStatus { + // Dummy number of reduces for the tests where push based shuffle is not enabled + val SHUFFLE_PUSH_DUMMY_NUM_REDUCES = 1 + + /** + * Separate a MergeStatuses received from an ExternalShuffleService into individual + * MergeStatus. The scheduler is responsible for providing the location information + * for the given ExternalShuffleService. + */ + def convertMergeStatusesToMergeStatusArr( + mergeStatuses: MergeStatuses, + loc: BlockManagerId): Seq[(Int, MergeStatus)] = { + assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length && + mergeStatuses.bitmaps.length == mergeStatuses.sizes.length) + val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port) + mergeStatuses.bitmaps.zipWithIndex.map { + case (bitmap, index) => + val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index)) + (mergeStatuses.reduceIds(index), mergeStatus) + } + } + + def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = { + new MergeStatus(loc, bitmap, size) + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 83fe450425146..f4b47e2bb0cdc 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -21,17 +21,19 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ +import org.roaringbitmap.RoaringBitmap import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} -import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} +import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -class MapOutputTrackerSuite extends SparkFunSuite { +class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf private def newTrackerMaster(sparkConf: SparkConf = conf) = { @@ -58,7 +60,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) @@ -82,7 +84,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() 
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -105,7 +107,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -140,7 +142,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { mapWorkerTracker.trackerEndpoint = mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) mapWorkerTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } @@ -183,7 +185,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Message size should be ~123B, and no exception should be thrown - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5)) val senderAddress = RpcAddress("localhost", 12345) @@ -217,7 +219,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostA with output size 2 // on hostA with output size 2 // on hostB with output size 3 - tracker.registerShuffle(10, 3) + tracker.registerShuffle(10, 3, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(2L), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -260,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. 
- masterTracker.registerShuffle(20, 100) + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) @@ -306,7 +308,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -332,6 +334,219 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } + test("SPARK-32921: master register and unregister merge result") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 4, 2) + assert(tracker.containsShuffle(10)) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap, 1000L)) + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + bitmap, 1000L)) + assert(tracker.getNumAvailableMergeResults(10) == 2) + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.stop() + rpcEnv.shutdown() + } + + test("SPARK-32921: get map sizes with merged shuffle") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + bitmap.add(3) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + bitmap, 3000L)) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1), + (ShuffleBlockId(10, 2, 0), size1000, 2))))) + + masterTracker.stop() + slaveTracker.stop() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + 
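The push-based shuffle tests in this suite lean on MergeStatus's bitmap bookkeeping. A hypothetical companion test (not part of this patch) that spells out those semantics using only APIs introduced here:

```scala
test("illustration: MergeStatus missing-map bookkeeping") {
  val bitmap = new RoaringBitmap()
  bitmap.add(0)
  bitmap.add(1)
  bitmap.add(3)
  // Out of 4 map outputs, only map index 2 was never merged, so it is the one
  // block that must still be fetched in the original, unmerged way.
  val status = MergeStatus(BlockManagerId("a", "hostA", 1000), bitmap, 3000L)
  assert(status.getMissingMaps(4) === Seq(2))
  assert(status.getNumMissingMapOutputs(4) === 1)
}
```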
+  test("SPARK-32921: get map statuses from merged shuffle") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
+    intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+    bitmap.add(2)
+    bitmap.add(3)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+      bitmap, 4000L))
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    assert(slaveTracker.getMapSizesForMergeResult(10, 0).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2),
+        (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: get map statuses for merged shuffle block chunks") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    val chunkBitmap = new RoaringBitmap()
+    chunkBitmap.add(0)
+    chunkBitmap.add(2)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 2, 0), size1000, 2))))
+    )
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: getPreferredLocationsForShuffle with MergeStatus") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    sc = new SparkContext("local", "test", conf.clone())
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    // Setup 6 map tasks
+    // on hostA with output size 2
+    // on hostA with output size 2
+    // on hostB with output size 3
+    // on hostB with output size 3
+    // on hostC with output size 1
+    // on hostC with output size 1
+    tracker.registerShuffle(10, 6, 1)
+    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 5))
+    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 6))
+    tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 7))
+    tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 8))
+    tracker.registerMapOutput(10, 4, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 9))
+    tracker.registerMapOutput(10, 5, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 10))
+
+    val rdd = sc.parallelize(1 to 6, 6).map(num => (num, num).asInstanceOf[Product2[Int, Int]])
+    val mockShuffleDep = mock(classOf[ShuffleDependency[Int, Int, _]])
+    when(mockShuffleDep.shuffleId).thenReturn(10)
+    when(mockShuffleDep.partitioner).thenReturn(new HashPartitioner(1))
+    when(mockShuffleDep.rdd).thenReturn(rdd)
+
+    // Prepare a MergeStatus that merges 5 out of the 6 blocks
+    val bitmap80 = new RoaringBitmap()
+    bitmap80.add(0)
+    bitmap80.add(1)
+    bitmap80.add(2)
+    bitmap80.add(3)
+    bitmap80.add(4)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap80, 11))
+
+    val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs1.nonEmpty)
+    assert(preferredLocs1.length === 1)
+    assert(preferredLocs1.head === "hostA")
+
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+    // Prepare another MergeStatus that merges only 1 out of the 6 blocks
+    val bitmap20 = new RoaringBitmap()
+    bitmap20.add(0)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap20, 2))
+
+    val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs2.nonEmpty)
+    assert(preferredLocs2.length === 2)
+    assert(preferredLocs2 === Seq("hostA", "hostB"))
+
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
+
   test("SPARK-34939: remote fetch using broadcast if broadcasted value is destroyed") {
     val newConf = new SparkConf
     newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -346,7 +561,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       rpcEnv.stop(masterTracker.trackerEndpoint)
       rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)

-      masterTracker.registerShuffle(20, 100)
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
       (0 until 100).foreach { i =>
         masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
           BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
@@ -368,9 +583,85 @@ class MapOutputTrackerSuite extends SparkFunSuite {
         shuffleStatus.cachedSerializedBroadcast.destroy(true)
       }
       val err = intercept[SparkException] {
-        MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
+        MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
+      }
+      assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses"))
+    }
+  }
+
+  test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") {
+    val newConf = new SparkConf
+    newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
+    newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast
+    newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+
+    // TorrentBroadcast is needed, so a SparkContext is required
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
+      val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+      val bitmap1 = new RoaringBitmap()
+      bitmap1.add(1)
+
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
+      (0 until 100).foreach { i =>
+        masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
       }
-      assert(err.getMessage.contains("Unable to deserialize broadcasted map statuses"))
+      masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
+        bitmap1, 1000L))
+
+      val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
+      val mapWorkerTracker = new MapOutputTrackerWorker(conf)
+      mapWorkerTracker.trackerEndpoint =
+        mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+      val fetchedBytes = mapWorkerTracker.trackerEndpoint
+        .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20))
+      assert(masterTracker.getNumAvailableMergeResults(20) == 1)
+      assert(masterTracker.getNumAvailableOutputs(20) == 100)
+
+      val mapOutput =
+        MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf)
+      val mergeOutput =
+        MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf)
+      assert(mapOutput.length == 100)
+      assert(mergeOutput.length == 1)
+      mapWorkerTracker.stop()
+      masterTracker.stop()
     }
   }
+
+  test("SPARK-32921: unregister merge result if it is present and contains the map Id") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    tracker.registerShuffle(10, 4, 2)
+    assert(tracker.containsShuffle(10))
+    val bitmap1 = new RoaringBitmap()
+    bitmap1.add(0)
+    bitmap1.add(1)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap1, 1000L))
+
+    val bitmap2 = new RoaringBitmap()
+    bitmap2.add(5)
+    bitmap2.add(6)
+    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+      bitmap2, 1000L))
+    assert(tracker.getNumAvailableMergeResults(10) == 2)
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0))
+    assert(tracker.getNumAvailableMergeResults(10) == 1)
+    tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(1))
+    assert(tracker.getNumAvailableMergeResults(10) == 1)
+    tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(5))
+    assert(tracker.getNumAvailableMergeResults(10) == 0)
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index e433f429000c1..d8088239870ba 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark

 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.benchmark.BenchmarkBase
-import org.apache.spark.scheduler.CompressedMapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, MergeStatus}
 import org.apache.spark.storage.BlockManagerId

 /**
@@ -50,7 +50,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {

     val shuffleId = 10

-    tracker.registerShuffle(shuffleId, numMaps)
+    tracker.registerShuffle(shuffleId, numMaps, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val r = new scala.util.Random(912)
     (0 until numMaps).foreach { i =>
       tracker.registerMapOutput(shuffleId, i,
@@ -66,7 +66,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
     var serializedMapStatusSizes = 0
     var serializedBroadcastSizes = 0

-    val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses(
+    val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeOutputStatuses(
       shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize,
       sc.getConf)
     serializedMapStatusSizes = serializedMapStatus.length
@@ -75,12 +75,13 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
     }

     benchmark.addCase("Serialization") { _ =>
-      MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
+      MapOutputTracker.serializeOutputStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
         tracker.isLocal, minBroadcastSize, sc.getConf)
     }

     benchmark.addCase("Deserialization") { _ =>
-      val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf)
+      val result = MapOutputTracker.deserializeOutputStatuses[MapStatus](
+        serializedMapStatus, sc.getConf)
       assert(result.length == numMaps)
     }

diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 56684d9b03271..126faec334e77 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.shuffle.ShuffleWriter
 import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
@@ -367,7 +367,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, shuffleDep)
-    mapTrackerMaster.registerShuffle(0, 1)
+    mapTrackerMaster.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)

     // first attempt -- its successful
     val context1 =
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 055ee0debeb12..707e1684f78fd 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated}
+import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, MergeStatus, SparkListenerBlockUpdated}
 import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
 import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
@@ -1956,7 +1956,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent)
     Files.write(bm2.diskBlockManager.getFile(shuffleIndex2).toPath(), shuffleIndexBlockContent)

-    mapOutputTracker.registerShuffle(0, 1)
+    mapOutputTracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val decomManager = new BlockManagerDecommissioner(conf, bm1)
     try {
       mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0))
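Across MapStatusesSerDeserBenchmark, ShuffleSuite and BlockManagerSuite, the call sites above all follow the same convention for the new registerShuffle argument: tests that actually exercise push-based shuffle pass a real reducer count, while every other caller passes MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES to opt out of merge-result tracking. A sketch of the two idioms, assuming code living in the org.apache.spark package (MapOutputTrackerMaster is private[spark]) and a tracker constructed as in these suites:

```scala
import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.scheduler.MergeStatus

// Illustrative only: the two registerShuffle idioms this patch applies.
def registerExamples(tracker: MapOutputTrackerMaster): Unit = {
  // Push-based shuffle exercised: 4 map tasks, 2 reduce partitions, so
  // merge results can later be registered per reduce partition.
  tracker.registerShuffle(10, 4, 2)

  // Merge results not under test: the dummy reducer count signals that no
  // merge-result bookkeeping is needed for this shuffle.
  tracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
}
```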