Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,18 @@ private[spark] object ThreadUtils {
* Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
* are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
*/
def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
def newDaemonCachedThreadPool(
prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
val threadFactory = namedThreadFactory(prefix)
new ThreadPoolExecutor(
0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
val threadPool = new ThreadPoolExecutor(
maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks
maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used
keepAliveSeconds,
TimeUnit.SECONDS,
new LinkedBlockingQueue[Runnable],
threadFactory)
threadPool.allowCoreThreadTimeOut(true)
threadPool
}

/**
Expand Down
45 changes: 45 additions & 0 deletions core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.Random

import org.scalatest.concurrent.Eventually._

import org.apache.spark.SparkFunSuite

class ThreadUtilsSuite extends SparkFunSuite {
Expand Down Expand Up @@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite {
}
}

test("newDaemonCachedThreadPool") {
val maxThreadNumber = 10
val startThreadsLatch = new CountDownLatch(maxThreadNumber)
val latch = new CountDownLatch(1)
val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool(
"ThreadUtilsSuite-newDaemonCachedThreadPool",
maxThreadNumber,
keepAliveSeconds = 2)
try {
for (_ <- 1 to maxThreadNumber) {
cachedThreadPool.execute(new Runnable {
override def run(): Unit = {
startThreadsLatch.countDown()
latch.await(10, TimeUnit.SECONDS)
}
})
}
startThreadsLatch.await(10, TimeUnit.SECONDS)
assert(cachedThreadPool.getActiveCount === maxThreadNumber)
assert(cachedThreadPool.getQueue.size === 0)

// Submit a new task and it should be put into the queue since the thread number reaches the
// limitation
cachedThreadPool.execute(new Runnable {
override def run(): Unit = {
latch.await(10, TimeUnit.SECONDS)
}
})

assert(cachedThreadPool.getActiveCount === maxThreadNumber)
assert(cachedThreadPool.getQueue.size === 1)

latch.countDown()
eventually(timeout(10.seconds)) {
// All threads should be stopped after keepAliveSeconds
assert(cachedThreadPool.getActiveCount === 0)
assert(cachedThreadPool.getPoolSize === 0)
}
} finally {
cachedThreadPool.shutdownNow()
}
}

test("sameThread") {
val callerThreadName = Thread.currentThread().getName()
val f = Future {
Expand Down