@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.util.concurrent.{ConcurrentLinkedQueue, Semaphore}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Semaphore}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
-import org.apache.spark.util.{ResetSystemProperties, ThreadUtils}
+import org.apache.spark.util.{ResetSystemProperties, SystemClock, ThreadUtils}
 
 class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalSparkContext
   with ResetSystemProperties with Eventually {
@@ -73,6 +73,10 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
       // workload we need to worry about.
       .set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, 1L)
 
+    if (whenToDecom == TaskStarted) {
+      // We are using accumulators below, make sure those are reported frequently.
+      conf.set(config.EXECUTOR_HEARTBEAT_INTERVAL.key, "10ms")
+    }
     sc = new SparkContext(master, "test", conf)
 
     // Wait for the executors to start
@@ -81,7 +85,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
       timeout = 60000) // 60s
 
     val input = sc.parallelize(1 to numParts, numParts)
-    val accum = sc.collectionAccumulator[String]("mapperRunAccumulator")
+    val accum = sc.longAccumulator("mapperRunAccumulator")
 
     val sleepIntervalMs = whenToDecom match {
       // Increase the window of time b/w task started and ended so that we can decom within that.
@@ -101,7 +105,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     // Create a new RDD where we have sleep in each partition, we are also increasing
     // the value of accumulator in each partition
     val baseRdd = input.mapPartitions { x =>
-      accum.add(SparkEnv.get.executorId)
+      accum.add(1)
       if (sleepIntervalMs > 0) {
         Thread.sleep(sleepIntervalMs)
       }
@@ -115,10 +119,11 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     // Listen for the job & block updates
     val executorRemovedSem = new Semaphore(0)
     val taskEndEvents = new ConcurrentLinkedQueue[SparkListenerTaskEnd]()
+    val executorsActuallyStarted = new ConcurrentHashMap[String, Boolean]()
     val blocksUpdated = ArrayBuffer.empty[SparkListenerBlockUpdated]
 
     def getCandidateExecutorToDecom: Option[String] = if (whenToDecom == TaskStarted) {
-      accum.value.asScala.headOption
+      executorsActuallyStarted.keySet().asScala.headOption
     } else {
       taskEndEvents.asScala.filter(_.taskInfo.successful).map(_.taskInfo.executorId).headOption
     }
@@ -135,6 +140,22 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
       override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
         blocksUpdated.append(blockUpdated)
       }
+
+      override def onExecutorMetricsUpdate(
+          executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = {
+        val executorId = executorMetricsUpdate.execId
+        if (executorId != SparkContext.DRIVER_IDENTIFIER) {
+          val validUpdate = executorMetricsUpdate
+            .accumUpdates
+            .flatMap(_._4)
+            .exists { accumInfo =>
+              accumInfo.name == accum.name && accumInfo.update.exists(_.asInstanceOf[Long] >= 1)
+            }
+          if (validUpdate) {
+            executorsActuallyStarted.put(executorId, java.lang.Boolean.TRUE)
+          }
+        }
+      }
     })
 
     // Cache the RDD lazily
@@ -151,7 +172,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     // This way we know that this executor had real data to migrate when it is subsequently
     // decommissioned below.
     val intervalMs = if (whenToDecom == TaskStarted) {
-      1.milliseconds
+      3.milliseconds
     } else {
       10.milliseconds
     }
@@ -170,24 +191,39 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     sched.decommissionExecutor(
       execToDecommission,
       ExecutorDecommissionInfo("", isHostDecommissioned = false))
+    val decomTime = new SystemClock().getTimeMillis()
 
     // Wait for job to finish.
     val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 15.seconds)
     assert(asyncCountResult === numParts)
     // All tasks finished, so accum should have been increased numParts times.
-    assert(accum.value.size() === numParts)
+    assert(accum.value === numParts)
 
     sc.listenerBus.waitUntilEmpty()
+    val taskEndEventsCopy = taskEndEvents.asScala
     if (shuffle) {
       // mappers & reducers which succeeded
-      assert(taskEndEvents.asScala.count(_.reason == Success) === 2 * numParts,
+      assert(taskEndEventsCopy.count(_.reason == Success) === 2 * numParts,
         s"Expected ${2 * numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     } else {
       // only mappers which executed successfully
-      assert(taskEndEvents.asScala.count(_.reason == Success) === numParts,
+      assert(taskEndEventsCopy.count(_.reason == Success) === numParts,
         s"Expected ${numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     }
 
+    val minTaskEndTime = taskEndEventsCopy.map(_.taskInfo.finishTime).min
+    val maxTaskEndTime = taskEndEventsCopy.map(_.taskInfo.finishTime).max
+
+    // Verify that the decom time matched our expectations
+    val decomAssertMsg = s"$whenToDecom: decomTime: $decomTime, minTaskEnd: $minTaskEndTime, " +
+      s"maxTaskEnd: $maxTaskEndTime"
+    whenToDecom match {
+      case TaskStarted => assert(minTaskEndTime > decomTime, decomAssertMsg)
+      case TaskEnded => assert(minTaskEndTime <= decomTime &&
+        decomTime < maxTaskEndTime, decomAssertMsg)
+      case JobEnded => assert(maxTaskEndTime <= decomTime, decomAssertMsg)
+    }
+
     // Wait for our respective blocks to have migrated
     eventually(timeout(30.seconds), interval(10.milliseconds)) {
       if (persist) {
@@ -223,7 +259,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     // cached data. Original RDD partitions should not be recomputed i.e. accum
     // should have same value like before
     assert(testRdd.count() === numParts)
-    assert(accum.value.size() === numParts)
+    assert(accum.value === numParts)
 
     val storageStatus = sc.env.blockManager.master.getStorageStatus
     val execIdToBlocksMapping = storageStatus.map(
@@ -246,6 +282,6 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     // cached data. Original RDD partitions should not be recomputed i.e. accum
     // should have same value like before
     assert(testRdd.count() === numParts)
-    assert(accum.value.size() === numParts)
+    assert(accum.value === numParts)
   }
 }