@@ -18,7 +18,6 @@
 package org.apache.spark.scheduler

 import java.nio.ByteBuffer
-import java.util
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
@@ -27,6 +26,9 @@ import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap, HashSet}
 import scala.util.Random

+import com.google.common.base.Ticker
+import com.google.common.cache.CacheBuilder
+
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.ExecutorMetrics
@@ -137,9 +139,21 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

+  // We add executors here when we first get decommission notification for them. Executors can
+  // continue to run even after being asked to decommission, but they will eventually exit.
   val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
-  // map of second to list of executors to clear form the above map
-  val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()
+
+  // When they exit and we know of that via heartbeat failure, we will add them to this cache.
+  // This cache is consulted to know if a fetch failure is because a source executor was
+  // decommissioned.
+  lazy val decommissionedExecutorsRemoved = CacheBuilder.newBuilder()
+    .expireAfterWrite(
+      conf.getLong("spark.decommissioningRememberAfterRemoval.seconds", 60L), TimeUnit.SECONDS)
+    .ticker(new Ticker {
+      override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
+    })
+    .build[String, ExecutorDecommissionInfo]()
+    .asMap()

   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
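
A minimal standalone sketch (not part of this patch) of the mechanism the new decommissionedExecutorsRemoved cache relies on: a Guava cache whose expireAfterWrite window is measured by a custom Ticker instead of wall-clock time, so entries disappear once the supplied clock has advanced past the retention period. The nowMillis variable below is a hypothetical stand-in for the scheduler's clock, assuming only Guava on the classpath:

    import java.util.concurrent.TimeUnit

    import com.google.common.base.Ticker
    import com.google.common.cache.CacheBuilder

    object TickerCacheSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for the scheduler's clock (illustrative only).
        var nowMillis = 0L

        // Expiry is measured by the custom Ticker, mirroring the construction above.
        val cache = CacheBuilder.newBuilder()
          .expireAfterWrite(60L, TimeUnit.SECONDS)
          .ticker(new Ticker {
            override def read(): Long = TimeUnit.MILLISECONDS.toNanos(nowMillis)
          })
          .build[String, String]()
          .asMap()

        cache.put("exec-1", "decommissioned")
        assert(cache.get("exec-1") == "decommissioned") // still visible inside the window

        nowMillis += TimeUnit.SECONDS.toMillis(61)
        assert(cache.get("exec-1") == null)             // treated as absent once the write is older than 60s
      }
    }

Driving the Ticker from clock.getTimeMillis() rather than System.nanoTime keeps the expiry controllable by whatever clock the scheduler is constructed with, which is presumably why the patch wires the clock into read().
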
@@ -924,13 +938,9 @@

   override def getExecutorDecommissionInfo(executorId: String)
     : Option[ExecutorDecommissionInfo] = synchronized {
-    import scala.collection.JavaConverters._
-    // Garbage collect old decommissioning entries
-    val secondsToGcUptil = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis())
-    val headMap = decommissioningExecutorsToGc.headMap(secondsToGcUptil)
-    headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
-    headMap.clear()
-    executorsPendingDecommission.get(executorId)
+    executorsPendingDecommission
+      .get(executorId)
+      .orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
   }

   override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
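
For clarity, a tiny self-contained illustration (hypothetical names, not part of the patch) of the lookup order used in getExecutorDecommissionInfo above: the pending map is consulted first, and only on a miss is the nullable result of the Guava cache's asMap() view wrapped in an Option, so an expired or unknown executor yields None:

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.mutable

    val pending = new mutable.HashMap[String, String]()
    val removed = new ConcurrentHashMap[String, String]()    // stands in for the expiring cache view
    removed.put("exec-2", "node decommissioned")

    def lookup(id: String): Option[String] =
      pending.get(id).orElse(Option(removed.get(id)))        // Option(null) becomes None

    assert(lookup("exec-2").contains("node decommissioned")) // falls through to the removed map
    assert(lookup("exec-3").isEmpty)                         // absent everywhere -> None
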
@@ -1037,14 +1047,8 @@
     }


-    val decomInfo = executorsPendingDecommission.get(executorId)
-    if (decomInfo.isDefined) {
-      val rememberSeconds =
-        conf.getInt("spark.decommissioningRememberAfterRemoval.seconds", 60)
-      val gcSecond = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis()) + rememberSeconds
-      decommissioningExecutorsToGc.computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty) +=
-        executorId
-    }
+    val decomInfo = executorsPendingDecommission.remove(executorId)
+    decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))

     if (reason != LossReasonPending) {
       executorIdToHost -= executorId