From 7a3da3dfbc4c3df58c82c3ab554d7734f808cbe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 19 Feb 2019 13:22:40 -0800 Subject: [PATCH] [SPARK-26891][YARN] Fixing flaky test in YarnSchedulerBackendSuite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The test "RequestExecutors reflects node blacklist and is serializable" is flaky because of multithreaded access to the mock task scheduler. For details check [Mockito FAQ (occasional exceptions like: WrongTypeOfReturnValue)](https://github.com/mockito/mockito/wiki/FAQ#is-mockito-thread-safe). So instead of mocking the task scheduler in the test, `TaskSchedulerImpl` is simply subclassed. This multithreaded access of the `nodeBlacklist()` method is coming from: 1) the unit test thread, by calling the method `prepareRequestExecutors()` 2) the `DriverEndpoint.onStart` which runs a periodic task that ends up calling this method. Tested with the existing unit test. Closes #23801 from attilapiros/SPARK-26891. 
Authored-by: “attilapiros” Signed-off-by: Marcelo Vanzin (cherry picked from commit e4e4e2b842bffba6805623f2258b27b162b451ba) --- .../cluster/YarnSchedulerBackendSuite.scala | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 7fac57ff68abc..bd2cf97426637 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster +import java.util.concurrent.atomic.AtomicReference + import scala.language.reflectiveCalls import org.mockito.Mockito.when @@ -27,15 +29,35 @@ import org.apache.spark.serializer.JavaSerializer class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext { + private var yarnSchedulerBackend: YarnSchedulerBackend = _ + + override def afterEach(): Unit = { + try { + if (yarnSchedulerBackend != null) { + yarnSchedulerBackend.stop() + } + } finally { + super.afterEach() + } + } + test("RequestExecutors reflects node blacklist and is serializable") { sc = new SparkContext("local", "YarnSchedulerBackendSuite") - val sched = mock[TaskSchedulerImpl] - when(sched.sc).thenReturn(sc) - val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { + // Subclassing the TaskSchedulerImpl here instead of using Mockito. For details see SPARK-26891. 
+ val sched = new TaskSchedulerImpl(sc) { + val blacklistedNodes = new AtomicReference[Set[String]]() + + def setNodeBlacklist(nodeBlacklist: Set[String]): Unit = blacklistedNodes.set(nodeBlacklist) + + override def nodeBlacklist(): Set[String] = blacklistedNodes.get() + } + + val yarnSchedulerBackendExtended = new YarnSchedulerBackend(sched, sc) { def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = { this.hostToLocalTaskCount = hostToLocalTaskCount } } + yarnSchedulerBackend = yarnSchedulerBackendExtended val ser = new JavaSerializer(sc.conf).newInstance() for { blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c")) @@ -45,16 +67,15 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc Map("a" -> 1, "b" -> 2) ) } { - yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount) - when(sched.nodeBlacklist()).thenReturn(blacklist) - val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested) + yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount) + sched.setNodeBlacklist(blacklist) + val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested) assert(req.requestedTotal === numRequested) assert(req.nodeBlacklist === blacklist) assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty) // Serialize to make sure serialization doesn't throw an error ser.serialize(req) } - sc.stop() } }