@@ -25,6 +25,9 @@ import scala.annotation.meta.param
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.util.control.NonFatal
 
+import org.mockito.Mockito.spy
+import org.mockito.Mockito.times
+import org.mockito.Mockito.verify
 import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.time.SpanSugar._
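Note: the three Mockito imports added above are what the rest of this patch builds on. A minimal sketch of the spy/verify pattern, using a hypothetical class that is not part of the patch:

    import org.mockito.Mockito.{spy, times, verify}

    // Hypothetical class, used only to illustrate the mechanics.
    class Registry {
      def remove(id: String): Unit = ()
    }

    val registry = spy(new Registry)             // a real instance, with calls recorded
    registry.remove("exec-1")                    // still runs the real method
    verify(registry, times(1)).remove("exec-1")  // passes: called exactly once
    verify(registry, times(0)).remove("exec-2")  // passes: never called with this argument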
@@ -235,6 +238,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
   var sparkListener: EventInfoRecordingListener = null
 
+  var blockManagerMaster: BlockManagerMaster = null
   var mapOutputTracker: MapOutputTrackerMaster = null
   var broadcastManager: BroadcastManager = null
   var securityMgr: SecurityManager = null
@@ -248,17 +252,18 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
    */
   val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
   // stub out BlockManagerMaster.getLocations to use our cacheLocations
-  val blockManagerMaster = new BlockManagerMaster(null, null, conf, true) {
-      override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
-        blockIds.map {
-          _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
-            getOrElse(Seq())
-        }.toIndexedSeq
-      }
-      override def removeExecutor(execId: String): Unit = {
-        // don't need to propagate to the driver, which we don't have
-      }
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map {
+        _.asRDDId.map { id => (id.rddId -> id.splitIndex)
+        }.flatMap { key => cacheLocations.get(key)
+        }.getOrElse(Seq())
+      }.toIndexedSeq
     }
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
 
   /** The list of results that DAGScheduler has collected. */
   val results = new HashMap[Int, Any]()
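A note on the refactor above: the anonymous BlockManagerMaster subclass is extracted into the named MyBlockManagerMaster so that the setup code (below) can hand Mockito's spy a plain, nameable class to wrap; the behavior is unchanged, only the getLocations chain is reflowed. Per block id, the override performs roughly this lookup (a standalone sketch; the helper name locate is illustrative, not part of the patch):

    // For an RDD block, key the test's cacheLocations map by (rddId, splitIndex);
    // non-RDD blocks and unknown keys resolve to no locations.
    def locate(blockId: BlockId): Seq[BlockManagerId] =
      blockId.asRDDId
        .map(id => (id.rddId, id.splitIndex))
        .flatMap(cacheLocations.get)
        .getOrElse(Seq())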
@@ -276,6 +281,16 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     override def jobFailed(exception: Exception): Unit = { failure = exception }
   }
 
+  class MyMapOutputTrackerMaster(
+      conf: SparkConf,
+      broadcastManager: BroadcastManager)
+    extends MapOutputTrackerMaster(conf, broadcastManager, true) {
+
+    override def sendTracker(message: Any): Unit = {
+      // no-op, just so we can stop this to avoid leaking threads
+    }
+  }
+
   override def beforeEach(): Unit = {
     super.beforeEach()
     init(new SparkConf())
@@ -293,11 +308,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     results.clear()
     securityMgr = new SecurityManager(conf)
     broadcastManager = new BroadcastManager(true, conf, securityMgr)
-    mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) {
-      override def sendTracker(message: Any): Unit = {
-        // no-op, just so we can stop this to avoid leaking threads
-      }
-    }
+    mapOutputTracker = spy(new MyMapOutputTrackerMaster(conf, broadcastManager))
+    blockManagerMaster = spy(new MyBlockManagerMaster(conf))
     scheduler = new DAGScheduler(
       sc,
       taskScheduler,
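With the named classes in place, the setup wraps real instances in spies: the suite keeps the real stubbed behavior (the cacheLocations-backed getLocations, the no-op sendTracker) while Mockito records each invocation for later verification. A sketch of what the wrapping buys (illustrative only):

    val master = spy(new MyBlockManagerMaster(conf))
    master.removeExecutor("hostA-exec")                     // the real override still runs
    verify(master, times(1)).removeExecutor("hostA-exec")   // and the call was recorded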
@@ -548,6 +560,56 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(mapStatus2(2).location.host === "hostB")
   }
 
+  test("SPARK-32003: All shuffle files for executor should be cleaned up on fetch failure") {
+    // reset the test context with the right shuffle service config
+    afterEach()
+    val conf = new SparkConf()
+    conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true")
+    init(conf)
+
+    val shuffleMapRdd = new MyRDD(sc, 3, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(3))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 3, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(reduceRdd, Array(0, 1, 2))
+    // Map stage completes successfully,
+    // two tasks are run on an executor on hostA and one on an executor on hostB
+    completeShuffleMapStageSuccessfully(0, 0, 3, Seq("hostA", "hostA", "hostB"))
+    // Now the executor on hostA is lost
+    runEvent(ExecutorLost("hostA-exec", ExecutorExited(-100, false, "Container marked as failed")))
+    // Executor is removed but shuffle files are not unregistered
+    verify(blockManagerMaster, times(1)).removeExecutor("hostA-exec")
+    verify(mapOutputTracker, times(0)).removeOutputsOnExecutor("hostA-exec")
+
+    // The MapOutputTracker has all the shuffle files
+    val mapStatuses = mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses
+    assert(mapStatuses.count(_ != null) === 3)
+    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostA-exec") === 2)
+    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1)
+
+    // Now a fetch failure from the lost executor occurs
+    complete(taskSets(1), Seq(
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)
+    ))
+    // blockManagerMaster.removeExecutor is not called again
+    // but shuffle files are unregistered
+    verify(blockManagerMaster, times(1)).removeExecutor("hostA-exec")
+    verify(mapOutputTracker, times(1)).removeOutputsOnExecutor("hostA-exec")
+
+    // Shuffle files for hostA-exec should be lost
+    assert(mapStatuses.count(_ != null) === 1)
+    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostA-exec") === 0)
+    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1)
+
+    // Additional fetch failure from the executor does not result in further call to
+    // mapOutputTracker.removeOutputsOnExecutor
+    complete(taskSets(1), Seq(
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 1, 0, "ignored"), null)
+    ))
+    verify(mapOutputTracker, times(1)).removeOutputsOnExecutor("hostA-exec")
+  }
+
   test("zero split job") {
     var numResults = 0
     var failureReason: Option[Exception] = None
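One subtlety the new test leans on: Mockito verification counts are cumulative over the spy's lifetime, not since the previous verify. So the repeated verify(..., times(1)) checks assert that no additional call has happened in the meantime:

    // Sketch: counts accumulate across the whole test.
    tracker.removeOutputsOnExecutor("hostA-exec")
    verify(tracker, times(1)).removeOutputsOnExecutor("hostA-exec")  // passes
    tracker.removeOutputsOnExecutor("hostA-exec")
    verify(tracker, times(1)).removeOutputsOnExecutor("hostA-exec")  // fails: 2 invocations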
@@ -765,8 +827,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     complete(taskSets(1), Seq(
       (Success, 42),
       (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)))
-    // this will get called
-    // blockManagerMaster.removeExecutor("hostA-exec")
+    verify(blockManagerMaster, times(1)).removeExecutor("hostA-exec")
     // ask the scheduler to try it again
     scheduler.resubmitFailedStages()
     // have the 2nd attempt pass
@@ -806,11 +867,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(reduceRdd, Array(0))
     completeShuffleMapStageSuccessfully(0, 0, 1)
     runEvent(ExecutorLost("hostA-exec", event))
+    verify(blockManagerMaster, times(1)).removeExecutor("hostA-exec")
     if (expectFileLoss) {
+      verify(mapOutputTracker, times(1)).removeOutputsOnExecutor("hostA-exec")
       intercept[MetadataFetchFailedException] {
         mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0)
       }
     } else {
+      verify(mapOutputTracker, times(0)).removeOutputsOnExecutor("hostA-exec")
       assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
         HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
     }