@@ -24,8 +24,7 @@ import scala.concurrent.duration._
 
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, Success,
-  TestUtils}
+import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
@@ -35,41 +34,51 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
     with ResetSystemProperties with Eventually {
 
   val numExecs = 3
+  val numParts = 3
 
   test(s"verify that an already running task which is going to cache data succeeds " +
     s"on a decommissioned executor") {
-    runDecomTest(true, false)
+    runDecomTest(true, false, true)
   }
 
   test(s"verify that shuffle blocks are migrated.") {
-    runDecomTest(false, true)
+    runDecomTest(false, true, false)
   }
 
   test(s"verify that both migrations can work at the same time.") {
-    runDecomTest(true, true)
+    runDecomTest(true, true, false)
   }
 
-  private def runDecomTest(persist: Boolean, shuffle: Boolean) = {
+  private def runDecomTest(persist: Boolean, shuffle: Boolean, migrateDuring: Boolean) = {
     val master = s"local-cluster[${numExecs}, 1, 1024]"
     val conf = new SparkConf().setAppName("test").setMaster(master)
       .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true)
       .set(config.STORAGE_DECOMMISSION_ENABLED, true)
       .set(config.STORAGE_RDD_DECOMMISSION_ENABLED, persist)
       .set(config.STORAGE_SHUFFLE_DECOMMISSION_ENABLED, shuffle)
+      // Just replicate blocks as fast as we can during testing, there isn't another
+      // workload we need to worry about.
       .set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, 1L)
 
+    // Force fetching to local disk
+    if (shuffle) {
+      conf.set("spark.network.maxRemoteBlockSizeFetchToMem", "1")
+    }
+
     sc = new SparkContext(master, "test", conf)
 
     // Create input RDD with 10 partitions
-    val input = sc.parallelize(1 to 10, 10)
+    val input = sc.parallelize(1 to numParts, numParts)
     val accum = sc.longAccumulator("mapperRunAccumulator")
     // Do a count to wait for the executors to be registered.
     input.count()
 
     // Create a new RDD where we have sleep in each partition, we are also increasing
     // the value of accumulator in each partition
     val sleepyRdd = input.mapPartitions { x =>
-      Thread.sleep(250)
+      if (migrateDuring) {
+        Thread.sleep(500)
+      }
       accum.add(1)
       x.map(y => (y, y))
     }
@@ -79,19 +88,26 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
     }
 
     // Listen for the job & block updates
-    val sem = new Semaphore(0)
+    val taskStartSem = new Semaphore(0)
+    val broadcastSem = new Semaphore(0)
     val taskEndEvents = ArrayBuffer.empty[SparkListenerTaskEnd]
     val blocksUpdated = ArrayBuffer.empty[SparkListenerBlockUpdated]
     sc.addSparkListener(new SparkListener {
+
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        sem.release()
+        taskStartSem.release()
       }
 
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         taskEndEvents.append(taskEnd)
       }
 
       override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+        // Once broadcast blocks start landing on the executors we're good to proceed.
+        // We don't rely only on task start as it can occur before the work is on the executor.
+        if (blockUpdated.blockUpdatedInfo.blockId.isBroadcast) {
+          broadcastSem.release()
+        }
         blocksUpdated.append(blockUpdated)
       }
     })
@@ -102,19 +118,32 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
       testRdd.persist()
     }
 
-    // Wait for all of the executors to start
+    // Wait for the first executor to start
     TestUtils.waitUntilExecutorsUp(sc = sc,
-      numExecutors = numExecs,
+      numExecutors = 1,
       timeout = 10000) // 10s
 
     // Start the computation of RDD - this step will also cache the RDD
     val asyncCount = testRdd.countAsync()
 
-    // Wait for the job to have started
-    sem.acquire(1)
+    // Wait for all of the executors to start
+    TestUtils.waitUntilExecutorsUp(sc = sc,
+      numExecutors = numExecs,
+      timeout = 10000) // 10s
 
-    // Give Spark a tiny bit to start the tasks after the listener says hello
-    Thread.sleep(50)
+    // Wait for the job to have started.
+    taskStartSem.acquire(1)
+    // Wait for each executor + driver to have its broadcast info delivered.
+    broadcastSem.acquire((numExecs + 1))
+
+    // Make sure the job is either mid-run or otherwise has data to migrate.
+    if (migrateDuring) {
+      // Give Spark a tiny bit to start executing after the broadcast blocks land.
+      // For me this works at 100ms; set to 300ms to allow for system variance.
+      Thread.sleep(300)
+    } else {
+      ThreadUtils.awaitResult(asyncCount, 15.seconds)
+    }
 
     // Decommission one of the executor
     val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend]
@@ -127,49 +156,58 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
 
     // Wait for job to finish
     val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 15.seconds)
-    assert(asyncCountResult === 10)
-    // All 10 tasks finished, so accum should have been increased 10 times
-    assert(accum.value === 10)
+    assert(asyncCountResult === numParts)
+    // All tasks finished, so accum should have been increased numParts times
+    assert(accum.value === numParts)
 
     // All tasks should be successful, nothing should have failed
     sc.listenerBus.waitUntilEmpty()
     if (shuffle) {
-      // 10 mappers & 10 reducers which succeeded
-      assert(taskEndEvents.count(_.reason == Success) === 20,
-        s"Expected 20 tasks got ${taskEndEvents.size} (${taskEndEvents})")
+      // mappers & reducers which succeeded
+      assert(taskEndEvents.count(_.reason == Success) === 2 * numParts,
+        s"Expected ${2 * numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     } else {
-      // 10 mappers which executed successfully
-      assert(taskEndEvents.count(_.reason == Success) === 10,
-        s"Expected 10 tasks got ${taskEndEvents.size} (${taskEndEvents})")
+      // only mappers which executed successfully
+      assert(taskEndEvents.count(_.reason == Success) === numParts,
+        s"Expected ${numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     }
 
     // Wait for our respective blocks to have migrated
     eventually(timeout(15.seconds), interval(10.milliseconds)) {
       if (persist) {
         // One of our blocks should have moved.
-        val blockLocs = blocksUpdated.map{ update =>
+        val rddUpdates = blocksUpdated.filter{ update =>
+          val blockId = update.blockUpdatedInfo.blockId
+          blockId.isRDD}
+        val blockLocs = rddUpdates.map{ update =>
           (update.blockUpdatedInfo.blockId.name,
             update.blockUpdatedInfo.blockManagerId)}
         val blocksToManagers = blockLocs.groupBy(_._1).mapValues(_.toSet.size)
         assert(!blocksToManagers.filter(_._2 > 1).isEmpty,
-          s"We should have a block that has been on multiple BMs in ${blocksUpdated}")
+          s"We should have a block that has been on multiple BMs in rdds:\n${rddUpdates} from:\n" +
+          s"${blocksUpdated}\nbut instead we got:\n${blocksToManagers}")
       }
       // If we're migrating shuffles we look for any shuffle block updates
       // as there is no block update on the initial shuffle block write.
       if (shuffle) {
-        val numLocs = blocksUpdated.filter{ update =>
+        val numDataLocs = blocksUpdated.filter{ update =>
+          val blockId = update.blockUpdatedInfo.blockId
+          blockId.isInstanceOf[ShuffleDataBlockId]
+        }.toSet.size
+        val numIndexLocs = blocksUpdated.filter{ update =>
           val blockId = update.blockUpdatedInfo.blockId
-          blockId.isShuffle || blockId.isInternalShuffle
+          blockId.isInstanceOf[ShuffleIndexBlockId]
         }.toSet.size
-        assert(numLocs > 0, s"No shuffle block updates in ${blocksUpdated}")
+        assert(numDataLocs >= 1, s"Expect shuffle data block updates in ${blocksUpdated}")
+        assert(numIndexLocs >= 1, s"Expect shuffle index block updates in ${blocksUpdated}")
       }
     }
 
     // Since the RDD is cached or shuffled so further usage of same RDD should use the
     // cached data. Original RDD partitions should not be recomputed i.e. accum
     // should have same value like before
-    assert(testRdd.count() === 10)
-    assert(accum.value === 10)
+    assert(testRdd.count() === numParts)
+    assert(accum.value === numParts)
 
     val storageStatus = sc.env.blockManager.master.getStorageStatus
     val execIdToBlocksMapping = storageStatus.map(
@@ -178,8 +216,8 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
     assert(execIdToBlocksMapping(execToDecommission).keys.filter(_.isRDD).toSeq === Seq(),
       "Cache blocks should be migrated")
     if (persist) {
-      // There should still be all 10 RDD blocks cached
-      assert(execIdToBlocksMapping.values.flatMap(_.keys).count(_.isRDD) === 10)
+      // There should still be all the RDD blocks cached
+      assert(execIdToBlocksMapping.values.flatMap(_.keys).count(_.isRDD) === numParts)
     }
   }
 }
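
Note on the synchronization pattern: the easiest part of this diff to miss is the broadcast-block semaphore. The new listener releases one permit per broadcast block update, so the test can wait until every executor (plus the driver) actually has the job's broadcast data before triggering decommissioning, instead of sleeping after the first task-start event. The sketch below isolates that pattern; it assumes an already running SparkContext named sc and an executor count numExecs, and simply mirrors the listener added in the diff rather than introducing any new Spark API.

import java.util.concurrent.Semaphore

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

// Released once for every broadcast block that lands on a block manager.
val broadcastSem = new Semaphore(0)

sc.addSparkListener(new SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    // Task-start events can fire before any work (or data) is actually on the
    // executor, so broadcast block arrival is used as the readiness signal.
    if (blockUpdated.blockUpdatedInfo.blockId.isBroadcast) {
      broadcastSem.release()
    }
  }
})

// After kicking off the async job: block until each executor plus the driver
// has reported at least one broadcast block before proceeding.
broadcastSem.acquire(numExecs + 1)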