
Commit 01c999e

zsxwing authored and tdas committed
[SPARK-20461][CORE][SS] Use UninterruptibleThread for Executor and fix the potential hang in CachedKafkaConsumer
## What changes were proposed in this pull request? This PR changes Executor's threads to `UninterruptibleThread` so that we can use `runUninterruptibly` in `CachedKafkaConsumer`. However, this is just best effort to avoid hanging forever. If the user uses`CachedKafkaConsumer` in another thread (e.g., create a new thread or Future), the potential hang may still happen. ## How was this patch tested? The new added test. Author: Shixiong Zhu <[email protected]> Closes #17761 from zsxwing/int.
1 parent 606432a commit 01c999e
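
For readers unfamiliar with `UninterruptibleThread`: inside `runUninterruptibly`, a `Thread.interrupt()` is recorded rather than delivered, and takes effect only after the block returns. Below is a minimal sketch of those semantics (timing simplified; the `sleep` is a stand-in for a blocking Kafka client call):

```scala
import org.apache.spark.util.UninterruptibleThread

// Sketch only: UninterruptibleThread is private[spark], so this compiles
// inside the org.apache.spark source tree, not against the public API.
val worker = new UninterruptibleThread("demo") {
  override def run(): Unit = {
    runUninterruptibly {
      // An interrupt arriving here is deferred, so this block cannot be
      // woken mid-operation (the failure mode behind KAFKA-1894).
      Thread.sleep(1000)
    }
    // The deferred interrupt is delivered once the block exits.
  }
}
worker.start()
worker.interrupt() // recorded, not thrown, while runUninterruptibly is active
worker.join()
```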

File tree

4 files changed: +50 -5 lines changed


core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 17 additions & 2 deletions
```diff
@@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory
 import java.net.{URI, URL}
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent._
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -84,7 +86,20 @@ private[spark] class Executor(
   }
 
   // Start worker thread pool
-  private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
+  private val threadPool = {
+    val threadFactory = new ThreadFactoryBuilder()
+      .setDaemon(true)
+      .setNameFormat("Executor task launch worker-%d")
+      .setThreadFactory(new ThreadFactory {
+        override def newThread(r: Runnable): Thread =
+          // Use UninterruptibleThread to run tasks so that we can allow running codes without being
+          // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
+          // will hang forever if some methods are interrupted.
+          new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
+      })
+      .build()
+    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+  }
   private val executorSource = new ExecutorSource(threadPool, executorId)
   // Pool used for threads that supervise task killing / cancellation
   private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
```
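
As a follow-up illustration, here is a self-contained sketch (names hypothetical) of the same factory pattern, showing that every pooled worker comes out as an `UninterruptibleThread` while `ThreadFactoryBuilder` still controls daemon status and naming:

```scala
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import com.google.common.util.concurrent.ThreadFactoryBuilder

import org.apache.spark.util.UninterruptibleThread

val factory = new ThreadFactoryBuilder()
  .setDaemon(true)
  .setNameFormat("demo-worker-%d")
  .setThreadFactory(new ThreadFactory {
    // The inner factory fixes the thread class; the builder then renames
    // each thread, so the placeholder name "unused" never survives.
    override def newThread(r: Runnable): Thread = new UninterruptibleThread(r, "unused")
  })
  .build()

val pool = Executors.newCachedThreadPool(factory).asInstanceOf[ThreadPoolExecutor]
pool.submit(new Runnable {
  override def run(): Unit =
    // Prints "true": the pooled worker is an UninterruptibleThread
    println(Thread.currentThread.isInstanceOf[UninterruptibleThread])
})
pool.shutdown()
```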

core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy
  *
  * Note: "runUninterruptibly" should be called only in `this` thread.
  */
-private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
+private[spark] class UninterruptibleThread(
+    target: Runnable,
+    name: String) extends Thread(target, name) {
+
+  def this(name: String) {
+    this(null, name)
+  }
 
   /** A monitor to protect "uninterruptible" and "interrupted" */
   private val uninterruptibleLock = new Object
```
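
The new `Runnable`-taking constructor is what makes the thread-factory wiring above possible: a `ThreadFactory` receives its work as a `Runnable` and cannot subclass per task. A brief sketch of the two construction styles (bodies hypothetical):

```scala
import org.apache.spark.util.UninterruptibleThread

// Original style: subclass and override run()
val a = new UninterruptibleThread("worker-a") {
  override def run(): Unit = println("subclassed body")
}

// New style: pass the work in as a Runnable, as a ThreadFactory must
val b = new UninterruptibleThread(new Runnable {
  override def run(): Unit = println("runnable body")
}, "worker-b")

a.start(); b.start()
a.join(); b.join()
```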

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.UninterruptibleThread
 
 class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
 
@@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     assert(failReason.isInstanceOf[FetchFailed])
   }
 
+  test("Executor's worker threads should be UninterruptibleThread") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("executor thread test")
+      .set("spark.ui.enabled", "false")
+    sc = new SparkContext(conf)
+    val executorThread = sc.parallelize(Seq(1), 1).map { _ =>
+      Thread.currentThread.getClass.getName
+    }.collect().head
+    assert(executorThread === classOf[UninterruptibleThread].getName)
+  }
+
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
```

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala

Lines changed: 13 additions & 2 deletions
```diff
@@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.kafka010.KafkaSource._
+import org.apache.spark.util.UninterruptibleThread
 
 
 /**
@@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private(
 
   case class AvailableOffsetRange(earliest: Long, latest: Long)
 
+  private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
+    case ut: UninterruptibleThread =>
+      ut.runUninterruptibly(body)
+    case _ =>
+      logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " +
+        "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894")
+      body
+  }
+
   /**
    * Return the available offset range of the current partition. It's a pair of the earliest offset
    * and the latest offset.
    */
-  def getAvailableOffsetRange(): AvailableOffsetRange = {
+  def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible {
     consumer.seekToBeginning(Set(topicPartition).asJava)
     val earliestOffset = consumer.position(topicPartition)
     consumer.seekToEnd(Set(topicPartition).asJava)
@@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private(
       offset: Long,
       untilOffset: Long,
       pollTimeoutMs: Long,
-      failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+      failOnDataLoss: Boolean):
+    ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible {
     require(offset < untilOffset,
       s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
     logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
```
