22 changes: 19 additions & 3 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,6 +25,7 @@ import java.nio.ByteBuffer
 import java.util.{Locale, Properties}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.locks.ReentrantLock
 import javax.annotation.concurrent.GuardedBy
 import javax.ws.rs.core.UriBuilder

@@ -85,6 +86,11 @@ private[spark] class Executor(

   private[executor] val conf = env.conf

+  // SPARK-40235: updateDependencies() uses a ReentrantLock instead of the `synchronized` keyword
+  // so that tasks can exit quickly if they are interrupted while waiting on another task to
+  // finish downloading dependencies.
+  private val updateDependenciesLock = new ReentrantLock()
+
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
   // must not have port specified.
@@ -969,13 +975,19 @@ private[spark] class Executor(
   /**
    * Download any missing dependencies if we receive a new set of files and JARs from the
    * SparkContext. Also adds any new JARs we fetched to the class loader.
+   * Visible for testing.
    */
-  private def updateDependencies(
+  private[executor] def updateDependencies(
       newFiles: Map[String, Long],
       newJars: Map[String, Long],
-      newArchives: Map[String, Long]): Unit = {
+      newArchives: Map[String, Long],
+      testStartLatch: Option[CountDownLatch] = None,
+      testEndLatch: Option[CountDownLatch] = None): Unit = {
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
-    synchronized {
+    updateDependenciesLock.lockInterruptibly()
+    try {
+      // For testing, so we can simulate a slow file download:
+      testStartLatch.foreach(_.countDown())
       // Fetch missing dependencies
       for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo(s"Fetching $name with timestamp $timestamp")
@@ -1018,6 +1030,10 @@
           }
         }
       }
+      // For testing, so we can simulate a slow file download:
+      testEndLatch.foreach(_.await())
+    } finally {
+      updateDependenciesLock.unlock()
     }
   }

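Why `lockInterruptibly()` rather than `synchronized`: a thread blocked entering a `synchronized` block does not respond to `Thread.interrupt()` until it acquires the monitor, whereas a thread blocked in `ReentrantLock.lockInterruptibly()` throws `InterruptedException` as soon as it is interrupted. The following standalone sketch (not part of this patch; `InterruptibleLockDemo` and the 500 ms pause are illustrative choices) demonstrates the behavior the change relies on:

import java.util.concurrent.CountDownLatch
import java.util.concurrent.locks.ReentrantLock

object InterruptibleLockDemo {
  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock()
    val holderReady = new CountDownLatch(1)

    // Holder: takes the lock and parks, simulating a slow dependency download.
    val holder = new Thread(() => {
      lock.lockInterruptibly()
      try {
        holderReady.countDown()
        Thread.sleep(Long.MaxValue)
      } catch {
        case _: InterruptedException => // interrupted: time to finish
      } finally {
        lock.unlock()
      }
    })
    holder.start()
    holderReady.await()

    // Waiter: blocks on the same lock, then is interrupted while waiting.
    val waiter = new Thread(() => {
      try {
        lock.lockInterruptibly() // throws InterruptedException while blocked
        try { /* the dependency update would happen here */ } finally lock.unlock()
      } catch {
        case _: InterruptedException => println("waiter exited promptly on interrupt")
      }
    })
    waiter.start()
    Thread.sleep(500) // crude: give the waiter time to block on the lock
    waiter.interrupt()
    waiter.join()

    holder.interrupt()
    holder.join()
  }
}

Swapping `lockInterruptibly()` for a `synchronized` block in the waiter would leave it parked until the holder released the monitor, which is exactly the hang this change avoids.
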
53 changes: 53 additions & 0 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -514,6 +514,59 @@ class ExecutorSuite extends SparkFunSuite
     }
   }

+  test("SPARK-40235: updateDependencies is interruptible when waiting on lock") {
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val env = createMockEnv(conf, serializer)
+    withExecutor("id", "localhost", env) { executor =>
+      val startLatch = new CountDownLatch(1)
+      val endLatch = new CountDownLatch(1)
+
+      // Start a thread to simulate a task that begins executing updateDependencies()
+      // and takes a long time to finish because file download is slow:
+      val slowLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty,
+          Some(startLatch),
+          Some(endLatch))
+      })
+      slowLibraryDownloadThread.start()
+
+      // Wait for that thread to acquire the lock:
+      startLatch.await()
+
+      // Start a second thread to simulate a task that blocks on the other task's
+      // dependency update:
+      val blockedLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty)
+      })
+      blockedLibraryDownloadThread.start()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        val threadState = blockedLibraryDownloadThread.getState
+        assert(Set(Thread.State.BLOCKED, Thread.State.WAITING).contains(threadState))
+      }
+
+      // Interrupt the blocked thread:
+      blockedLibraryDownloadThread.interrupt()
+
+      // The thread should exit:
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(blockedLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+
+      // Allow the first thread to finish and exit:
+      endLatch.countDown()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(slowLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
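
The test's determinism comes from a two-latch handshake rather than sleeps: `testStartLatch` proves the first thread already holds the lock before the second thread starts, and `testEndLatch` parks the first thread inside the critical section until the assertions have run. A minimal sketch of that handshake in isolation (assumed names, not from the patch):

import java.util.concurrent.CountDownLatch

object LatchHandshakeDemo {
  def main(args: Array[String]): Unit = {
    val startLatch = new CountDownLatch(1)
    val endLatch = new CountDownLatch(1)

    val worker = new Thread(() => {
      // Inside the "critical section": announce entry, then hold position.
      startLatch.countDown()
      endLatch.await()
      println("worker released")
    })
    worker.start()

    startLatch.await()   // guaranteed: the worker is now inside the critical section
    // ... a real test would now exercise a second, blocked thread ...
    endLatch.countDown() // let the worker finish
    worker.join()
  }
}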