@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
     }
   }
 
+  // Unlike TestUtils.withListener, it also waits for the job to be done
+  def withListener(sc: SparkContext, listener: RootStageAwareListener)
+      (body: SparkListener => Unit): Unit = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
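+      // Drain the listener bus, then block until the listener has observed the
+      // job-end event, so callers can assert on final state right after the body runs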
+      sc.listenerBus.waitUntilEmpty()
+      listener.waitForJobDone()
+    } finally {
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   test("decommission workers should not result in job failure") {
     val maxTaskFailures = 2
     val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
         Thread.sleep(5 * 1000L); 1
       }.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
         val sleepTimeSeconds = if (pid == 0) 1 else 10
         Thread.sleep(sleepTimeSeconds * 1000L)
@@ -212,22 +225,27 @@ class DecommissionWorkerSuite
       override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         val taskInfo = taskEnd.taskInfo
         if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
-          taskEnd.stageAttemptId == 0) {
+          taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
           decommissionWorkerOnMaster(workerToDecom,
             "decommission worker after task on it is done")
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
         val executorId = SparkEnv.get.executorId
-        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
-        Thread.sleep(sleepTimeSeconds * 1000L)
+        val context = TaskContext.get()
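+        // Sleep only in the first attempt of the first stage attempt, so that
+        // task retries triggered by the decommission finish quickly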
+        if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
+          val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+          Thread.sleep(sleepTimeSeconds * 1000L)
+        }
         List(1).iterator
       }, preservesPartitioning = true)
         .repartition(1).mapPartitions(iter => {
           val context = TaskContext.get()
           if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+            // Wait a bit for the decommissioning to be triggered in the listener
+            Thread.sleep(5000)
             // MapIndex is explicitly -1 to force the entire host to be decommissioned
             // However, this will cause both tasks in the preceding stage to be rerun, since the host here is
             // "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -265,23 +283,31 @@ class DecommissionWorkerSuite
     override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
       jobEnd.jobResult match {
         case JobSucceeded => jobDone.set(true)
+        case JobFailed(exception) => logError(s"Job failed", exception)
       }
     }
 
     protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}
 
     protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}
 
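+    // Builds a compact "stageId:stageAttemptId:taskIndex:taskAttemptNumber-status"
+    // key that identifies a task attempt across stage retries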
+    private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
+      String = {
+      s"${stageId}:${stageAttemptId}:" +
+        s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
+    }
+
     override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
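+      // Log every task start to make test failures easier to diagnose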
+      val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
+      logInfo(s"Task started: $signature")
       if (isRootStageId(taskStart.stageId)) {
         rootTasksStarted.add(taskStart.taskInfo)
         handleRootTaskStart(taskStart)
       }
     }
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
-        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
       logInfo(s"Task End $taskSignature")
       tasksFinished.add(taskSignature)
       if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +317,13 @@ class DecommissionWorkerSuite
     }
 
     def getTasksFinished(): Seq[String] = {
-      assert(jobDone.get(), "Job isn't successfully done yet")
-      tasksFinished.asScala.toSeq
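+      // No completion check here any more; callers wait via waitForJobDone() first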
+      tasksFinished.asScala.toList
+    }
+
+    def waitForJobDone(): Unit = {
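+      // The job-end event is delivered asynchronously, so poll rather than assert once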
+      eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        assert(jobDone.get(), "Job isn't successfully done yet")
+      }
     }
   }
 