diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ba8e4d69ba755..d21b9d9833e9e 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.storage.BlockManagerId
 
 /**
  * :: DeveloperApi ::
@@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
     shuffleId, this)
 
+  /**
+   * Stores the list of chosen external shuffle services that handle the shuffle merge
+   * requests from mappers in this shuffle map stage.
+   */
+  private[spark] var mergerLocs: Seq[BlockManagerId] = Nil
+
+  def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
+    if (mergerLocs != null) {
+      this.mergerLocs = mergerLocs
+    }
+  }
+
+  def getMergerLocs: Seq[BlockManagerId] = mergerLocs
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 4bc49514fc5ad..b38d0e5c617b9 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1945,4 +1945,51 @@ package object config {
       .version("3.0.1")
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
+    ConfigBuilder("spark.shuffle.push.enabled")
+      .doc("Set to 'true' to enable push-based shuffle on the client side. This works in " +
+        "conjunction with the server side flag " +
+        "spark.shuffle.server.mergedShuffleFileManagerImpl, which needs to be set to the " +
+        "appropriate org.apache.spark.network.shuffle.MergedShuffleFileManager " +
+        "implementation for push-based shuffle to be enabled.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
+    ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
+      .doc("Maximum number of shuffle push merger locations cached for push-based shuffle. " +
+        "Currently, shuffle push merger locations are external shuffle services, which are " +
+        "responsible for handling pushed blocks, merging them, and serving merged blocks " +
+        "for later shuffle fetch.")
+      .version("3.1.0")
+      .intConf
+      .createWithDefault(500)
+
+  private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO =
+    ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio")
+      .doc("The minimum number of shuffle merger locations required to enable push-based " +
+        "shuffle for a stage, specified as a ratio of the number of partitions in the " +
+        "child stage. For example, a reduce stage with 100 partitions and the default " +
+        "value of 0.05 requires at least 5 unique merger locations to enable push-based " +
+        "shuffle. Merger locations are currently defined as external shuffle services.")
+      .version("3.1.0")
+      .doubleConf
+      .createWithDefault(0.05)
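The interaction between `spark.shuffle.push.mergersMinThresholdRatio` and `spark.shuffle.push.mergersMinStaticThreshold` is easiest to see with the numbers from the doc strings above. The following standalone sketch is illustrative only (the object name and values are not part of this patch), and for simplicity it applies the ratio directly to the partition count; the YARN backend later in this patch applies it to a desired merger count derived from partitions and executor slots:

```scala
object MergerThresholdExample {
  def main(args: Array[String]): Unit = {
    val numPartitions = 1000    // partitions of the child (reduce) stage
    val thresholdRatio = 0.05   // spark.shuffle.push.mergersMinThresholdRatio
    val staticThreshold = 5.0   // spark.shuffle.push.mergersMinStaticThreshold

    // Push-based shuffle is enabled only if at least the larger of the two
    // thresholds can be satisfied by available merger locations.
    val minMergersNeeded =
      math.max(staticThreshold, math.floor(numPartitions * thresholdRatio)).toInt
    println(s"mergers needed: $minMergersNeeded") // prints "mergers needed: 50"
  }
}
```

With only 100 partitions the ratio term drops to 5 and the static threshold becomes the binding constraint; together the two configs avoid both under-provisioning mergers on large stages and over-eager enablement on small ones.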
+  private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD =
+    ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold")
+      .doc("The static threshold for the number of shuffle push merger locations that " +
+        "should be available in order to enable push-based shuffle for a stage. Note that " +
+        s"this config works in conjunction with " +
+        s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}: the maximum of " +
+        "spark.shuffle.push.mergersMinStaticThreshold and the number of mergers derived " +
+        s"from the ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} ratio is needed to " +
+        "enable push-based shuffle for a stage. For example, with 1000 partitions for the " +
+        "child stage, spark.shuffle.push.mergersMinStaticThreshold set to 5, and " +
+        s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, at least 50 " +
+        "mergers are needed to enable push-based shuffle for that stage.")
+      .version("3.1.0")
+      .doubleConf
+      .createWithDefault(5)
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 13b766e654832..6fb0fb93f253b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -249,6 +249,8 @@ private[spark] class DAGScheduler(
   private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf)
+
   /**
    * Called by the TaskSetManager to report task's starting.
    */
@@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler(
     execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores))
   }
 
+  /**
+   * If push-based shuffle is enabled, set the shuffle services to be used for the given
+   * shuffle map stage for block push/merge.
+   *
+   * Even when dynamic resource allocation kicks in and significantly reduces the number
+   * of available active executors, we can still get sufficient shuffle service locations
+   * for block push/merge from the historical locations of past executors.
+   */
+  private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
+    // TODO(SPARK-32920) Handle stage reuse/retry cases separately, as without the finalize
+    // TODO changes we cannot disable shuffle merge for the retry/reuse cases
+    val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
+      stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+
+    if (mergerLocs.nonEmpty) {
+      stage.shuffleDep.setMergerLocs(mergerLocs)
+      logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
+        s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
+
+      logDebug("List of shuffle push merger locations " +
+        s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+    } else {
+      logInfo("No available merger locations." +
+        s" Push-based shuffle disabled for $stage (${stage.name})")
+    }
+  }
+
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
     logDebug("submitMissingTasks(" + stage + ")")
@@ -1281,6 +1310,12 @@
     stage match {
       case s: ShuffleMapStage =>
         outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
+        // Only generate merger locations for a given shuffle dependency once. This way,
+        // even if this stage gets retried, it would still be merging blocks using the
+        // same set of shuffle services.
+        if (pushBasedShuffleEnabled) {
+          prepareShuffleServicesForShuffleMapStage(s)
+        }
       case s: ResultStage =>
         outputCommitCoordinator.stageStart(
           stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
@@ -2027,6 +2062,11 @@
     if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) {
       executorFailureEpoch(execId) = currentEpoch
       logInfo(s"Executor lost: $execId (epoch $currentEpoch)")
+      if (pushBasedShuffleEnabled) {
+        // Remove the fetch-failed host from the shuffle push merger list for push-based shuffle
+        hostToUnregisterOutputs.foreach(
+          host => blockManagerMaster.removeShufflePushMergerLocation(host))
+      }
       blockManagerMaster.removeExecutor(execId)
       clearCacheLocs()
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index a566d0a04387c..b2acdb3e12a6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerId
 
 /**
  * A backend interface for scheduling systems that allows plugging in different ones under
@@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend {
    */
   def maxNumConcurrentTasks(rp: ResourceProfile): Int
 
+  /**
+   * Get the list of host locations for push-based shuffle.
+   *
+   * Currently push-based shuffle is disabled for both the stage retry and stage reuse cases
+   * (e.g., when a few partitions are lost due to failure). Hence this method should be
+   * invoked only once for a ShuffleDependency.
+   * @return list of external shuffle service locations
+   */
+  def getShufflePushMergerLocations(
+      numPartitions: Int,
+      resourceProfileId: Int): Seq[BlockManagerId] = Nil
+
 }
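Every other backend inherits the default `Nil` above, which keeps push-based shuffle disabled. The stub below is a hypothetical illustration of the contract only: `SchedulerBackend` is `private[spark]`, so this compiles only inside Spark, and the class name, host, and port are invented (7337 is the default `spark.shuffle.service.port`):

```scala
package org.apache.spark.scheduler

import org.apache.spark.resource.ResourceProfile
import org.apache.spark.storage.BlockManagerId

// Hypothetical backend advertising a fixed merger list. A real backend
// should consult the BlockManagerMaster, as YarnSchedulerBackend does below.
private[spark] class FixedMergerBackend extends SchedulerBackend {
  override def start(): Unit = {}
  override def stop(): Unit = {}
  override def reviveOffers(): Unit = {}
  override def defaultParallelism(): Int = 2
  override def maxNumConcurrentTasks(rp: ResourceProfile): Int = 0

  override def getShufflePushMergerLocations(
      numPartitions: Int,
      resourceProfileId: Int): Seq[BlockManagerId] = {
    Seq(BlockManagerId(
      BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merger-host.example.com", 7337))
  }
}
```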
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 49e32d04d450a..c6a4457d8f910 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -145,4 +145,6 @@ private[spark] object BlockManagerId {
   def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
     blockManagerIdCache.get(id)
   }
+
+  private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f544d47b8e13c..fe1a5aef9499c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -125,6 +125,26 @@ class BlockManagerMaster(
     driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId))
   }
 
+  /**
+   * Get a list of unique shuffle service locations where an executor was successfully
+   * registered in the past, for block push/merge with push-based shuffle.
+   */
+  def getShufflePushMergerLocations(
+      numMergersNeeded: Int,
+      hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+    driverEndpoint.askSync[Seq[BlockManagerId]](
+      GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
+  }
+
+  /**
+   * Remove the host from the candidate list of shuffle push mergers. This can be
+   * triggered if there is a FetchFailedException on the host.
+   * @param host the host to be removed
+   */
+  def removeShufflePushMergerLocation(host: String): Unit = {
+    driverEndpoint.askSync[Seq[BlockManagerId]](RemoveShufflePushMergerLocation(host))
+  }
+
   def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
     driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
   }
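A merger location is an ordinary `BlockManagerId` whose executor-id slot carries the special `SHUFFLE_MERGER_IDENTIFIER`; only the host and the external shuffle service port are meaningful. A Spark-internal sketch (the host is a placeholder, and 7337 is the default `spark.shuffle.service.port`):

```scala
import org.apache.spark.storage.BlockManagerId

val mergerLoc = BlockManagerId(
  BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "host1.example.com", 7337)
assert(mergerLoc.executorId == "shuffle-push-merger")
```

The two new `BlockManagerMaster` methods traffic in exactly these ids: `getShufflePushMergerLocations` hands out up to the requested number of them, filtering out the caller-supplied hosts, while `removeShufflePushMergerLocation` retires a host, e.g. after a `FetchFailedException`.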
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index a7532a9870fae..4d565511704d4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint(
   // Mapping from block id to the set of block managers that have the block.
   private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
 
+  // Mapping from host name to shuffle push merger services where the current app
+  // registered an executor in the past. The oldest hosts are removed when the
+  // maxRetainedMergerLocations size is reached, in favor of newer locations.
+  private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]()
+
+  // Maximum number of merger locations to cache.
+  private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS)
+
   private val askThreadPool =
     ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100)
   private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
@@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint(
 
   val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf)
 
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+
   logInfo("BlockManagerMasterEndpoint up")
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)
   // && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)`
@@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint(
     case GetBlockStatus(blockId, askStorageEndpoints) =>
       context.reply(blockStatus(blockId, askStorageEndpoints))
 
+    case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) =>
+      context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
+
+    case RemoveShufflePushMergerLocation(host) =>
+      context.reply(removeShufflePushMergerLocation(host))
+
     case IsExecutorAlive(executorId) =>
       context.reply(blockManagerIdByExecutor.contains(executorId))
 
@@ -360,6 +376,17 @@
 
   }
 
+  private def addMergerLocation(blockManagerId: BlockManagerId): Unit = {
+    if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) {
+      val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER,
+        blockManagerId.host, externalShuffleServicePort)
+      if (shuffleMergerLocations.size >= maxRetainedMergerLocations) {
+        shuffleMergerLocations -= shuffleMergerLocations.head._1
+      }
+      shuffleMergerLocations(shuffleServerId.host) = shuffleServerId
+    }
+  }
+
   private def removeExecutor(execId: String): Unit = {
     logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
     blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
@@ -526,6 +553,10 @@
 
       blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(),
         maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus)
+
+      if (pushBasedShuffleEnabled) {
+        addMergerLocation(id)
+      }
     }
     listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize,
       Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
@@ -657,6 +688,40 @@
     }
   }
 
+  private def getShufflePushMergerLocations(
+      numMergersNeeded: Int,
+      hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+    val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet
+    val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_))
+    val filteredMergersWithExecutors = filteredBlockManagerHosts.map(
+      BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort))
+    // Enough mergers are available as part of the active executors list
+    if (filteredMergersWithExecutors.size >= numMergersNeeded) {
+      filteredMergersWithExecutors.toSeq
+    } else {
+      // Fill the remaining delta from the retained mergers list (hosts without active
+      // executors), choosing randomly when more hosts are available than needed.
+      val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host)
+      val filteredMergersWithoutExecutors = shuffleMergerLocations.values
+        .filterNot(x => hostsToFilter.contains(x.host))
+        .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host))
+      val randomFilteredMergersLocations =
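Two details of the endpoint logic above are worth calling out. First, `shuffleMergerLocations` is an insertion-ordered `LinkedHashMap` capped at `maxRetainedMergerLocations`, so the oldest host is evicted first. A standalone sketch of that eviction behavior (the names and the cap of 3 are illustrative, not from this patch):

```scala
import scala.collection.mutable

object MergerCacheSketch {
  def main(args: Array[String]): Unit = {
    val maxRetained = 3
    val mergers = new mutable.LinkedHashMap[String, Int]() // host -> port

    def add(host: String): Unit = {
      if (!mergers.contains(host)) {
        // Evict the oldest entry (the head of the insertion-ordered map).
        if (mergers.size >= maxRetained) mergers -= mergers.head._1
        mergers(host) = 7337
      }
    }

    Seq("hostA", "hostB", "hostC", "hostD").foreach(add)
    println(mergers.keys.mkString(", ")) // hostB, hostC, hostD -- hostA evicted
  }
}
```

Second, `getShufflePushMergerLocations` prefers hosts that currently run executors and only tops up the shortfall from this retained list, randomizing via `Utils.randomize` when more retained hosts are available than needed.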
+        if (filteredMergersWithoutExecutors.size >
+          numMergersNeeded - filteredMergersWithExecutors.size) {
+          Utils.randomize(filteredMergersWithoutExecutors)
+            .take(numMergersNeeded - filteredMergersWithExecutors.size)
+        } else {
+          filteredMergersWithoutExecutors
+        }
+      filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations
+    }
+  }
+
+  private def removeShufflePushMergerLocation(host: String): Unit = {
+    if (shuffleMergerLocations.contains(host)) {
+      shuffleMergerLocations.remove(host)
+    }
+  }
+
   /**
    * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages.
    */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index bbc076cea9ba8..afe416a55ed0d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages {
   case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
 
   case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
+
+  case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String])
+    extends ToBlockManagerMaster
+
+  case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster
+
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b743ab6507117..6ccf65b737c1a 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2541,6 +2541,14 @@ private[spark] object Utils extends Logging {
     master == "local" || master.startsWith("local[")
   }
 
+  /**
+   * Push-based shuffle can only be enabled when the external shuffle service is enabled
+   * (the external shuffle service requirement is relaxed under spark.testing).
+   */
+  def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
+    conf.get(PUSH_BASED_SHUFFLE_ENABLED) &&
+      (conf.get(IS_TESTING).getOrElse(false) || conf.get(SHUFFLE_SERVICE_ENABLED))
+  }
+
   /**
    * Return whether dynamic allocation is enabled in the given conf.
    */
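From the caller's perspective, this gate means the client-side flag alone is not enough. A Spark-internal sketch (`Utils` is `private[spark]`, so this only compiles inside Spark):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

val conf = new SparkConf().set("spark.shuffle.push.enabled", "true")
assert(!Utils.isPushBasedShuffleEnabled(conf)) // external shuffle service still off

conf.set("spark.shuffle.service.enabled", "true")
assert(Utils.isPushBasedShuffleEnabled(conf)) // both flags on => enabled
```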
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 55280fc578310..144489c5f7922 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
       .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)
       .set(Network.RPC_ASK_TIMEOUT, "5s")
+      .set(PUSH_BASED_SHUFFLE_ENABLED, true)
   }
 
   private def makeSortShuffleManager(): SortShuffleManager = {
@@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     }
   }
 
+  test("SPARK-32919: Shuffle push merger locations should be bounded within" +
+    " spark.shuffle.push.maxRetainedMergerLocations") {
+    assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty)
+    makeBlockManager(100, "execA",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    makeBlockManager(100, "execB",
+      transferService = Some(new MockBlockTransferService(10, "hostB")))
+    makeBlockManager(100, "execC",
+      transferService = Some(new MockBlockTransferService(10, "hostC")))
+    makeBlockManager(100, "execD",
+      transferService = Some(new MockBlockTransferService(10, "hostD")))
+    makeBlockManager(100, "execE",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4)
+    assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted ===
+      Seq("hostC", "hostD", "hostA", "hostB").sorted)
+    assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3)
+  }
+
+  test("SPARK-32919: Prefer active executor locations for shuffle push mergers") {
+    makeBlockManager(100, "execA",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    makeBlockManager(100, "execB",
+      transferService = Some(new MockBlockTransferService(10, "hostB")))
+    makeBlockManager(100, "execC",
+      transferService = Some(new MockBlockTransferService(10, "hostC")))
+    makeBlockManager(100, "execD",
+      transferService = Some(new MockBlockTransferService(10, "hostD")))
+    makeBlockManager(100, "execE",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4)
+
+    master.removeExecutor("execA")
+    master.removeExecutor("execE")
+
+    assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3)
+    assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted ===
+      Seq("hostC", "hostB", "hostD").sorted)
+    assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted ===
+      Seq("hostB", "hostA", "hostC", "hostD").sorted)
+  }
+
   test("SPARK-33387 Support ordered shuffle block migration") {
     val blocks: Seq[ShuffleBlockInfo] = Seq(
       ShuffleBlockInfo(1, 0L),
@@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1)))
   }
 
-  class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
+  class MockBlockTransferService(
+      val maxFailures: Int,
+      override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService {
     var numCalls = 0
     var tempFileManager: DownloadFileManager = null
 
@@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
     override def close(): Unit = {}
 
-    override def hostName: String = { "MockBlockTransferServiceHost" }
-
     override def port: Int = { 63332 }
 
     override def uploadBlock(
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 20624c743bc22..8fb408041ca9d 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.util.io.ChunkedByteBufferInputStream
@@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     }.getMessage
     assert(message.contains(expected))
   }
+
+  test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" +
+    " and SHUFFLE_SERVICE_ENABLED are true") {
+    val conf = new SparkConf()
+    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, false)
+    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    conf.set(SHUFFLE_SERVICE_ENABLED, true)
+    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+  }
 }
 
 private class SimpleExtension
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index b42bdb9816600..22002bb32004d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler.cluster
 
 import java.util.EnumSet
-import java.util.concurrent.atomic.{AtomicBoolean}
+import java.util.concurrent.atomic.AtomicBoolean
 import javax.servlet.DispatcherType
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
 
 import org.apache.spark.SparkContext
 import org.apache.spark.deploy.security.HadoopDelegationTokenManager
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.rpc._
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
 
 /**
  * Abstract Yarn scheduler backend that contains common logic
@@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend(
   /** Attempt ID. This is unset for client-mode schedulers */
   private var attemptId: Option[ApplicationAttemptId] = None
 
+  private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master
+
+  private val minMergersThresholdRatio =
+    conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO)
+
+  private val minMergersStaticThreshold =
+    conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD)
+
+  private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS)
+
+  private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0)
+
   /**
    * Bind to YARN. This *must* be done before calling [[start()]].
    *
@@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend(
     totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
   }
 
+  override def getShufflePushMergerLocations(
+      numPartitions: Int,
+      resourceProfileId: Int): Seq[BlockManagerId] = {
+    // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage;
+    // TODO we can use better heuristics to calculate numMergersDesired for a stage.
+    val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) {
+      maxNumExecutors
+    } else {
+      numExecutors
+    }
+    val tasksPerExecutor = sc.resourceProfileManager
+      .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf)
+    val numMergersDesired = math.min(
+      math.max(1, math.ceil(numPartitions * 1.0 / tasksPerExecutor).toInt), maxExecutors)
+    val minMergersNeeded = math.max(minMergersStaticThreshold,
+      math.floor(numMergersDesired * minMergersThresholdRatio).toInt)
+
+    // Request numMergersDesired shuffle mergers from the BlockManagerMasterEndpoint;
+    // if fewer than minMergersNeeded are available, disable push-based shuffle.
+    val mergerLocations = blockManagerMaster
+      .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes())
+    if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) {
+      Seq.empty[BlockManagerId]
+    } else {
+      logDebug(s"The number of shuffle mergers desired is ${numMergersDesired}," +
+        s" and ${mergerLocations.length} locations are available")
+      mergerLocations
+    }
+  }
+
   /**
    * Add filters to the SparkUI.
    */
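To make the sizing heuristic concrete, here is the arithmetic with illustrative numbers (1000 reduce partitions, 4 concurrent tasks per executor, dynamic allocation capped at 400 executors; all values are examples, not defaults):

```scala
val numPartitions = 1000
val tasksPerExecutor = 4
val maxExecutors = 400

val numMergersDesired = math.min(
  math.max(1, math.ceil(numPartitions * 1.0 / tasksPerExecutor).toInt), maxExecutors)
// min(max(1, ceil(250.0)), 400) = 250

val minMergersNeeded = math.max(5.0, math.floor(numMergersDesired * 0.05)).toInt
// max(5, floor(12.5)) = 12: if fewer than 12 mergers come back (and fewer than
// the 250 desired), push-based shuffle is disabled for the stage.
println((numMergersDesired, minMergersNeeded)) // (250,12)
```

Note the `* 1.0` widening in the desired-merger computation: with plain `Int` operands the division would truncate before `math.ceil` runs, silently rounding the desired merger count down.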