diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index bffa1ffc5d39c..e401c395a0486 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,6 +158,8 @@ private[spark] class TaskSchedulerImpl(
 
   private[scheduler] var barrierCoordinator: RpcEndpoint = null
 
+  protected val defaultRackValue: Option[String] = None
+
   private def maybeInitBarrierCoordinator(): Unit = {
     if (barrierCoordinator == null) {
       barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus,
@@ -394,9 +396,10 @@ private[spark] class TaskSchedulerImpl(
         executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
         newExecAvail = true
       }
-      for (rack <- getRackForHost(o.host)) {
-        hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
-      }
+    }
+    val hosts = offers.map(_.host).toSet.toSeq
+    for ((host, Some(rack)) <- hosts.zip(getRacksForHosts(hosts))) {
+      hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += host
     }
 
     // Before making any offers, remove any nodes from the blacklist whose blacklist has expired. Do
@@ -830,8 +833,25 @@ private[spark] class TaskSchedulerImpl(
     blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(Set.empty)
   }
 
-  // By default, rack is unknown
-  def getRackForHost(value: String): Option[String] = None
+  /**
+   * Get the rack for one host.
+   *
+   * Note that [[getRacksForHosts]] should be preferred when possible as that can be much
+   * more efficient.
+   */
+  def getRackForHost(host: String): Option[String] = {
+    getRacksForHosts(Seq(host)).head
+  }
+
+  /**
+   * Get racks for multiple hosts.
+   *
+   * The returned Sequence will be the same length as the hosts argument and can be zipped
+   * together with the hosts argument.
+   */
+  def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    hosts.map(_ => defaultRackValue)
+  }
 
   private def waitBackendReady(): Unit = {
     if (backend.isReady) {
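Reviewer note: a minimal, hypothetical sketch of how the new batched API is intended to be consumed. The host names and the `scheduler` value are placeholders, not part of the patch; the only contract relied on is the one stated in the scaladoc above, that the result has the same length as the input, so a positional zip is safe.

```scala
// Hypothetical usage of TaskSchedulerImpl.getRacksForHosts (not part of the patch).
// `scheduler` stands in for any TaskSchedulerImpl; host names are made up.
val hosts: Seq[String] = Seq("host1", "host2", "host3")
val racks: Seq[Option[String]] = scheduler.getRacksForHosts(hosts)
// Same length as `hosts`, so the two sequences can be zipped positionally:
hosts.zip(racks).foreach {
  case (host, Some(rack)) => println(s"$host is on rack $rack")
  case (host, None) => println(s"rack for $host is unknown")
}
```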
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 3977c0bafa57e..144422022c22f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -186,8 +186,24 @@ private[spark] class TaskSetManager(
 
   // Add all our tasks to the pending lists. We do this in reverse order
   // of task index so that tasks with low indices get launched first.
-  for (i <- (0 until numTasks).reverse) {
-    addPendingTask(i)
+  addPendingTasks()
+
+  private def addPendingTasks(): Unit = {
+    val (_, duration) = Utils.timeTakenMs {
+      for (i <- (0 until numTasks).reverse) {
+        addPendingTask(i, resolveRacks = false)
+      }
+      // Resolve the rack for each host. This can be slow, so de-dupe the list of hosts,
+      // and assign the rack to all relevant task indices.
+      val (hosts, indicesForHosts) = pendingTasksForHost.toSeq.unzip
+      val racks = sched.getRacksForHosts(hosts)
+      racks.zip(indicesForHosts).foreach {
+        case (Some(rack), indices) =>
+          pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) ++= indices
+        case (None, _) => // no rack, nothing to do
+      }
+    }
+    logDebug(s"Adding pending tasks took $duration ms")
   }
 
   /**
@@ -214,7 +230,9 @@ private[spark] class TaskSetManager(
   private[scheduler] var emittedTaskSizeWarning = false
 
   /** Add a task to all the pending-task lists that it should be on. */
-  private[spark] def addPendingTask(index: Int) {
+  private[spark] def addPendingTask(
+      index: Int,
+      resolveRacks: Boolean = true): Unit = {
     for (loc <- tasks(index).preferredLocations) {
       loc match {
         case e: ExecutorCacheTaskLocation =>
@@ -234,8 +252,11 @@ private[spark] class TaskSetManager(
         case _ =>
       }
       pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
-      for (rack <- sched.getRackForHost(loc.host)) {
-        pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
+
+      if (resolveRacks) {
+        sched.getRackForHost(loc.host).foreach { rack =>
+          pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
+        }
       }
     }
 
@@ -331,7 +352,7 @@ private[spark] class TaskSetManager(
         val executors = prefs.flatMap(_ match {
           case e: ExecutorCacheTaskLocation => Some(e.executorId)
           case _ => None
-        });
+        })
         if (executors.contains(execId)) {
           speculatableTasks -= index
           return Some((index, TaskLocality.PROCESS_LOCAL))
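Reviewer note: a self-contained sketch of the de-dupe-then-zip pattern that `addPendingTasks()` now uses. The keys of `pendingTasksForHost` are already unique hosts, so the resolver is asked about each host exactly once and the answers are fanned back out to every task index on that host. The resolver and the host/index data below are stand-ins, not the patch's real code.

```scala
// Stand-alone illustration of the batching in addPendingTasks(); getRacksForHosts here
// is a fake stand-in for sched.getRacksForHosts, and the hosts/indices are made up.
import scala.collection.mutable.{ArrayBuffer, HashMap}

val pendingTasksForHost = HashMap(
  "hostA" -> ArrayBuffer(0, 1),
  "hostB" -> ArrayBuffer(2, 3))
val pendingTasksForRack = HashMap.empty[String, ArrayBuffer[Int]]

def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] =
  hosts.map(h => Some("/rack-" + h.last))  // fake topology: one lookup per distinct host

// Keys of the host map are already de-duplicated, so resolve once per host...
val (hosts, indicesForHosts) = pendingTasksForHost.toSeq.unzip
// ...then fan the racks back out to every task index on that host.
getRacksForHosts(hosts).zip(indicesForHosts).foreach {
  case (Some(rack), indices) =>
    pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) ++= indices
  case (None, _) => // unknown rack, nothing to do
}
```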
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ad03194fe4c31..79160d05b3e60 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -22,8 +22,8 @@ import java.util.{Properties, Random}
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.mockito.ArgumentMatchers.{any, anyInt, anyString}
-import org.mockito.Mockito.{mock, never, spy, times, verify, when}
+import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyString}
+import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark._
@@ -68,17 +68,27 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
 // Get the rack for a given host
 object FakeRackUtil {
   private val hostToRack = new mutable.HashMap[String, String]()
+  var numBatchInvocation = 0
+  var numSingleHostInvocation = 0
 
   def cleanUp() {
     hostToRack.clear()
+    numBatchInvocation = 0
+    numSingleHostInvocation = 0
   }
 
   def assignHostToRack(host: String, rack: String) {
     hostToRack(host) = rack
  }
 
-  def getRackForHost(host: String): Option[String] = {
-    hostToRack.get(host)
+  def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    assert(hosts.toSet.size == hosts.size) // no dups in hosts
+    if (hosts.nonEmpty && hosts.length != 1) {
+      numBatchInvocation += 1
+    } else if (hosts.length == 1) {
+      numSingleHostInvocation += 1
+    }
+    hosts.map(hostToRack.get(_))
   }
 }
@@ -98,6 +108,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
   val speculativeTasks = new ArrayBuffer[Int]
 
   val executors = new mutable.HashMap[String, String]
+
+  // this must be initialized before addExecutor
+  override val defaultRackValue: Option[String] = Some("default")
   for ((execId, host) <- liveExecutors) {
     addExecutor(execId, host)
   }
@@ -143,8 +156,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
     }
   }
 
-
-  override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value)
+  override def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    FakeRackUtil.getRacksForHosts(hosts)
+  }
 }
 
 /**
@@ -1311,7 +1325,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY)
 
     // Assert the task has been black listed on the executor it was last executed on.
-    when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
+    when(taskSetManagerSpy.addPendingTask(anyInt(), anyBoolean())).thenAnswer(
       (invocationOnMock: InvocationOnMock) => {
         val task: Int = invocationOnMock.getArgument(0)
         assert(taskSetManager.taskSetBlacklistHelperOpt.get.
@@ -1323,7 +1337,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val e = new ExceptionFailure("a", "b", Array(), "c", None)
     taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, e)
 
-    verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt())
+    verify(taskSetManagerSpy, times(1)).addPendingTask(0, false)
   }
 
   test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") {
@@ -1595,4 +1609,50 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     verify(sched.dagScheduler).taskEnded(manager.tasks(3), Success, result.value(),
       result.accumUpdates, info3)
   }
+
+  test("SPARK-13704 Rack Resolution is done with a batch of de-duped hosts") {
+    val conf = new SparkConf()
+      .set(config.LOCALITY_WAIT, 0L)
+      .set(config.LOCALITY_WAIT_RACK, 1L)
+    sc = new SparkContext("local", "test", conf)
+    // Create a cluster with 20 racks, with hosts spread out among them
+    val execAndHost = (0 to 199).map { i =>
+      FakeRackUtil.assignHostToRack("host" + i, "rack" + (i % 20))
+      ("exec" + i, "host" + i)
+    }
+    sched = new FakeTaskScheduler(sc, execAndHost: _*)
+    // make a taskset with preferred locations on the first 100 hosts in our cluster
+    val locations = new ArrayBuffer[Seq[TaskLocation]]()
+    for (i <- 0 to 99) {
+      locations += Seq(TaskLocation("host" + i))
+    }
+    val taskSet = FakeTask.createTaskSet(100, locations: _*)
+    val clock = new ManualClock
+    // make sure we only do one rack resolution call, for the entire batch of hosts, as this
+    // can be expensive. The FakeTaskScheduler calls rack resolution more than the real one
+    // -- that is outside of the scope of this test, we just want to check the task set manager.
+    FakeRackUtil.numBatchInvocation = 0
+    FakeRackUtil.numSingleHostInvocation = 0
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    assert(FakeRackUtil.numBatchInvocation === 1)
+    assert(FakeRackUtil.numSingleHostInvocation === 0)
+    // with rack locality, reject an offer on a host with an unknown rack
+    assert(manager.resourceOffer("otherExec", "otherHost", TaskLocality.RACK_LOCAL).isEmpty)
+    (0 until 20).foreach { rackIdx =>
+      (0 until 5).foreach { offerIdx =>
+        // if we offer hosts which are not in preferred locations,
+        // we'll reject them at NODE_LOCAL level,
+        // but accept them at RACK_LOCAL level if they're on OK racks
+        val hostIdx = 100 + rackIdx
+        assert(manager.resourceOffer("exec" + hostIdx, "host" + hostIdx, TaskLocality.NODE_LOCAL)
+          .isEmpty)
+        assert(manager.resourceOffer("exec" + hostIdx, "host" + hostIdx, TaskLocality.RACK_LOCAL)
+          .isDefined)
+      }
+    }
+    // check no more expensive calls to the rack resolution. manager.resourceOffer() will call
+    // the single-host resolution, but the real rack resolution would have cached all hosts
+    // by that point.
+    assert(FakeRackUtil.numBatchInvocation === 1)
+  }
 }
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
index 0a7a16f468fbd..2288bb55d8b47 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
@@ -138,9 +138,8 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy(
         // Only filter out the ratio which is larger than 0, which means the current host can
         // still be allocated with new container request.
         val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray
-        val racks = hosts.map { h =>
-          resolver.resolve(yarnConf, h)
-        }.toSet
+        val racks = resolver.resolve(hosts).map(_.getNetworkLocation)
+          .filter(_ != null).toSet
         containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray)
 
         // Minus 1 each time when the host is used. When the current ratio is 0,
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
index c711d088f2116..cab32724e13a6 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
@@ -17,24 +17,100 @@
 
 package org.apache.spark.deploy.yarn
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.base.Strings
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net._
+import org.apache.hadoop.util.ReflectionUtils
 import org.apache.hadoop.yarn.util.RackResolver
 import org.apache.log4j.{Level, Logger}
 
+import org.apache.spark.internal.Logging
+
 /**
- * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily override the
- * default behavior, since YARN's class self-initializes the first time it's called, and
- * future calls all use the initial configuration.
+ * Re-implement YARN's [[RackResolver]] for hadoop releases without YARN-9332.
+ * This also allows Spark tests to easily override the default behavior, since YARN's class
+ * self-initializes the first time it's called, and future calls all use the initial configuration.
 */
-private[yarn] class SparkRackResolver {
+private[spark] class SparkRackResolver(conf: Configuration) extends Logging {
 
   // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
   if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
     Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
   }
 
-  def resolve(conf: Configuration, hostName: String): String = {
-    RackResolver.resolve(conf, hostName).getNetworkLocation()
+  private val dnsToSwitchMapping: DNSToSwitchMapping = {
+    val dnsToSwitchMappingClass =
+      conf.getClass(CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+        classOf[ScriptBasedMapping], classOf[DNSToSwitchMapping])
+    ReflectionUtils.newInstance(dnsToSwitchMappingClass, conf)
+        .asInstanceOf[DNSToSwitchMapping] match {
+      case c: CachedDNSToSwitchMapping => c
+      case o => new CachedDNSToSwitchMapping(o)
+    }
+  }
+
+  def resolve(hostName: String): String = {
+    coreResolve(Seq(hostName)).head.getNetworkLocation
+  }
+
+  /**
+   * Added in SPARK-13704.
+   * This should be changed to `RackResolver.resolve(conf, hostNames)`
+   * in hadoop releases with YARN-9332.
+   */
+  def resolve(hostNames: Seq[String]): Seq[Node] = {
+    coreResolve(hostNames)
+  }
+
+  private def coreResolve(hostNames: Seq[String]): Seq[Node] = {
+    val nodes = new ArrayBuffer[Node]
+    // dnsToSwitchMapping is thread-safe
+    val rNameList = dnsToSwitchMapping.resolve(hostNames.toList.asJava).asScala
+    if (rNameList == null || rNameList.isEmpty) {
+      hostNames.foreach(nodes += new NodeBase(_, NetworkTopology.DEFAULT_RACK))
+      logInfo(s"Got an error when resolving hostNames. " +
+        s"Falling back to ${NetworkTopology.DEFAULT_RACK} for all")
+    } else {
+      for ((hostName, rName) <- hostNames.zip(rNameList)) {
+        if (Strings.isNullOrEmpty(rName)) {
+          nodes += new NodeBase(hostName, NetworkTopology.DEFAULT_RACK)
+          logDebug(s"Could not resolve $hostName. " +
+            s"Falling back to ${NetworkTopology.DEFAULT_RACK}")
+        } else {
+          nodes += new NodeBase(hostName, rName)
+        }
+      }
+    }
+    nodes.toList
+  }
+}
+
+/**
+ * Utility to resolve the rack for hosts in an efficient manner.
+ * It will cache the rack for individual hosts to avoid
+ * repeatedly performing the same expensive lookup.
+ */
+object SparkRackResolver extends Logging {
+  @volatile private var instance: SparkRackResolver = _
+
+  /**
+   * It will return the static resolver instance. If there is already an instance, the passed
+   * conf is entirely ignored. If there is not a shared instance, it will create one with the
+   * given conf.
+   */
+  def get(conf: Configuration): SparkRackResolver = {
+    if (instance == null) {
+      synchronized {
+        if (instance == null) {
+          instance = new SparkRackResolver(conf)
+        }
+      }
+    }
+    instance
   }
 }
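Reviewer note: a hedged usage sketch of the resolver added above. `SparkRackResolver.get` returns a process-wide instance (the first caller's `conf` wins; later `conf`s are ignored), and wrapping the mapping in `CachedDNSToSwitchMapping` means each distinct host hits the underlying topology script/table at most once. The host names below are placeholders.

```scala
// Hypothetical caller-side usage (not part of the patch); host names are made up.
import org.apache.hadoop.conf.Configuration
import org.apache.spark.deploy.yarn.SparkRackResolver

val resolver = SparkRackResolver.get(new Configuration())
// One batched call for many hosts; unresolvable hosts fall back to
// NetworkTopology.DEFAULT_RACK rather than failing.
val nodes = resolver.resolve(Seq("host1", "host2", "host3"))
nodes.foreach(node => println(s"${node.getName} -> ${node.getNetworkLocation}"))
```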
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 5d939cfd41f9b..1dc9d49f17a14 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -432,7 +432,7 @@ private[yarn] class YarnAllocator(
       override def run(): Unit = {
         try {
           for (allocatedContainer <- remainingAfterHostMatches) {
-            val rack = resolver.resolve(conf, allocatedContainer.getNodeId.getHost)
+            val rack = resolver.resolve(allocatedContainer.getNodeId.getHost)
             matchContainerToRequest(allocatedContainer, rack, containersToUse,
               remainingAfterRackMatches)
           }
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index cf16edf16c034..7c67493c33160 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -83,7 +83,7 @@ private[spark] class YarnRMClient extends Logging {
       localResources: Map[String, LocalResource]): YarnAllocator = {
     require(registered, "Must register AM before creating allocator.")
     new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, appAttemptId, securityMgr,
-      localResources, new SparkRackResolver())
+      localResources, SparkRackResolver.get(conf))
   }
 
   /**
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
index 029382133ddf2..d466ed77a929e 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
@@ -17,23 +17,23 @@
 
 package org.apache.spark.scheduler.cluster
 
-import org.apache.hadoop.yarn.util.RackResolver
-import org.apache.log4j.{Level, Logger}
+import org.apache.hadoop.net.NetworkTopology
 
 import org.apache.spark._
+import org.apache.spark.deploy.yarn.SparkRackResolver
 import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
 
 private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {
 
-  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
-  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
-    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
-  }
+  override val defaultRackValue: Option[String] = Some(NetworkTopology.DEFAULT_RACK)
+
+  private[spark] val resolver = SparkRackResolver.get(sc.hadoopConfiguration)
 
-  // By default, rack is unknown
-  override def getRackForHost(hostPort: String): Option[String] = {
-    val host = Utils.parseHostPort(hostPort)._1
-    Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
+  override def getRacksForHosts(hostPorts: Seq[String]): Seq[Option[String]] = {
+    val hosts = hostPorts.map(Utils.parseHostPort(_)._1)
+    resolver.resolve(hosts).map { node =>
+      Option(node.getNetworkLocation)
+    }
   }
 }
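Reviewer note: other cluster managers can plug into the same two extension points that `YarnScheduler` overrides above, `defaultRackValue` and `getRacksForHosts`. The subclass below is a hypothetical sketch only (the class name and static topology are invented); since `TaskSchedulerImpl` is `private[spark]`, it assumes the code lives inside Spark's own scheduler package tree.

```scala
// Hypothetical custom scheduler (not part of the patch) showing the extension points
// TaskSchedulerImpl now exposes: defaultRackValue and getRacksForHosts.
package org.apache.spark.scheduler.cluster

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.TaskSchedulerImpl

private[spark] class StaticRackScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {

  // Rack to assume when a host is missing from the static topology below.
  override val defaultRackValue: Option[String] = Some("/default-rack")

  private val staticTopology = Map("host1" -> "/rack1", "host2" -> "/rack2")

  override def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
    // Keep the result aligned with the input so callers can zip it back against hosts.
    hosts.map(host => staticTopology.get(host).orElse(defaultRackValue))
  }
}
```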
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 42b59663af0b3..59291af72eacb 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -31,6 +31,7 @@ import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.config._
@@ -38,9 +39,9 @@ import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.SplitInfo
 import org.apache.spark.util.ManualClock
 
-class MockResolver extends SparkRackResolver {
+class MockResolver extends SparkRackResolver(SparkHadoopUtil.get.conf) {
 
-  override def resolve(conf: Configuration, hostName: String): String = {
+  override def resolve(hostName: String): String = {
     if (hostName == "host3") "/rack2" else "/rack1"
   }