Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service

import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import javax.annotation.concurrent.GuardedBy
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
Expand Down Expand Up @@ -66,21 +66,19 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
/** Concurrent hash table containing all the current executions. */
private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
private val executionsLock = new Object

/** Graveyard of tombstones of executions that were abandoned and removed. */
private val abandonedTombstones = CacheBuilder
.newBuilder()
.maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE))
.build[ExecuteKey, ExecuteInfo]()

/** None if there are no executions. Otherwise, the time when the last execution was removed. */
@GuardedBy("executionsLock")
private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())
/** The time when the last execution was removed. */
private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis())

/** Executor for the periodic maintenance */
@GuardedBy("executionsLock")
private var scheduledExecutor: Option[ScheduledExecutorService] = None
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()

/**
* Create a new ExecuteHolder and register it with this global manager and with its session.
Expand Down Expand Up @@ -118,11 +116,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging {

sessionHolder.addExecuteHolder(executeHolder)

executionsLock.synchronized {
if (!executions.isEmpty()) {
lastExecutionTimeMs = None
}
}
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")

schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started.
Expand Down Expand Up @@ -151,11 +144,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
executions.remove(key)
executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)

executionsLock.synchronized {
if (executions.isEmpty) {
lastExecutionTimeMs = Some(System.currentTimeMillis())
}
}
updateLastExecutionTime()

logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")

Expand Down Expand Up @@ -197,7 +186,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
*/
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
if (executions.isEmpty) {
Left(lastExecutionTimeMs.get)
Left(lastExecutionTimeMs.getAcquire())
} else {
Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
}
Expand All @@ -212,39 +201,40 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

private[connect] def shutdown(): Unit = {
executionsLock.synchronized {
scheduledExecutor.foreach { executor =>
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
scheduledExecutor = None
val executor = scheduledExecutor.getAndSet(null)
if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}

// note: this does not cleanly shut down the executions, but the server is shutting down.
executions.clear()
abandonedTombstones.invalidateAll()

executionsLock.synchronized {
if (lastExecutionTimeMs.isEmpty) {
lastExecutionTimeMs = Some(System.currentTimeMillis())
}
}
updateLastExecutionTime()
}

/**
* Updates the last execution time after the last execution has been removed.
*/
private def updateLastExecutionTime(): Unit = {
lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))
}

/**
* Schedules periodic maintenance checks if it is not already scheduled. The checks are looking
* for executions that have not been closed, but are left with no RPC attached to them, and
* removes them after a timeout.
*/
private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
scheduledExecutor match {
case Some(_) => // Already running.
case None =>
private def schedulePeriodicChecks(): Unit = {
var executor = scheduledExecutor.getAcquire()
if (executor == null) {
executor = Executors.newSingleThreadScheduledExecutor()
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL)
logInfo(
log"Starting thread for cleanup of abandoned executions every " +
log"${MDC(LogKeys.INTERVAL, interval)} ms")
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
scheduledExecutor.get.scheduleAtFixedRate(
executor.scheduleAtFixedRate(
() => {
try {
val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
Expand All @@ -256,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
interval,
interval,
TimeUnit.MILLISECONDS)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service

import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import javax.annotation.concurrent.GuardedBy
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
Expand All @@ -40,8 +40,6 @@ import org.apache.spark.util.ThreadUtils
*/
class SparkConnectSessionManager extends Logging {

private val sessionsLock = new Object

private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
new ConcurrentHashMap[SessionKey, SessionHolder]()

Expand All @@ -52,8 +50,8 @@ class SparkConnectSessionManager extends Logging {
.build[SessionKey, SessionHolderInfo]()

/** Executor for the periodic maintenance */
@GuardedBy("sessionsLock")
private var scheduledExecutor: Option[ScheduledExecutorService] = None
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()

private def validateSessionId(
key: SessionKey,
Expand All @@ -75,8 +73,6 @@ class SparkConnectSessionManager extends Logging {
val holder = getSession(
key,
Some(() => {
// Executed under sessionsState lock in getSession, to guard against concurrent removal
// and insertion into closedSessionsCache.
validateSessionCreate(key)
val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
holder.initializeSession()
Expand Down Expand Up @@ -168,17 +164,14 @@ class SparkConnectSessionManager extends Logging {

def closeSession(key: SessionKey): Unit = {
val sessionHolder = removeSessionHolder(key)
// Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by
// getOrCreateIsolatedSession.
// Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession.
sessionHolder.foreach(shutdownSessionHolder(_))
}

private[connect] def shutdown(): Unit = {
sessionsLock.synchronized {
scheduledExecutor.foreach { executor =>
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
scheduledExecutor = None
val executor = scheduledExecutor.getAndSet(null)
if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}

// note: this does not cleanly shut down the sessions, but the server is shutting down.
Expand All @@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging {
*
* The checks are looking to remove sessions that expired.
*/
private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
scheduledExecutor match {
case Some(_) => // Already running.
case None =>
private def schedulePeriodicChecks(): Unit = {
var executor = scheduledExecutor.getAcquire()
if (executor == null) {
executor = Executors.newSingleThreadScheduledExecutor()
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
logInfo(
log"Starting thread for cleanup of expired sessions every " +
log"${MDC(INTERVAL, interval)} ms")
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
scheduledExecutor.get.scheduleAtFixedRate(
executor.scheduleAtFixedRate(
() => {
try {
val defaultInactiveTimeoutMs =
Expand All @@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging {
interval,
interval,
TimeUnit.MILLISECONDS)
}
}
}

Expand Down Expand Up @@ -255,24 +249,18 @@ class SparkConnectSessionManager extends Logging {

// .. and remove them.
toRemove.foreach { sessionHolder =>
// This doesn't use closeSession to be able to do the extra last chance check under lock.
val removedSession = {
// Last chance - check expiration time and remove under lock if expired.
val info = sessionHolder.getSessionHolderInfo
if (shouldExpire(info, System.currentTimeMillis())) {
logInfo(
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
log"and will be closed.")
removeSessionHolder(info.key)
} else {
None
val info = sessionHolder.getSessionHolderInfo
if (shouldExpire(info, System.currentTimeMillis())) {
logInfo(
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
log"and will be closed.")
removeSessionHolder(info.key)
try {
shutdownSessionHolder(sessionHolder)
} catch {
case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
}
}
// do shutdown and cleanup outside of lock.
try removedSession.foreach(shutdownSessionHolder(_))
catch {
case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
}
}
logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
Expand Down Expand Up @@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache(

// Visible for testing.
private[service] def shutdown(): Unit = queryCacheLock.synchronized {
scheduledExecutor.foreach { executor =>
val executor = scheduledExecutor.getAndSet(null)
if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
scheduledExecutor = None
}

@GuardedBy("queryCacheLock")
Expand All @@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache(
private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
private val taggedQueriesLock = new Object

@GuardedBy("queryCacheLock")
private var scheduledExecutor: Option[ScheduledExecutorService] = None
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()

/** Schedules periodic checks if it is not already scheduled */
private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
scheduledExecutor match {
case Some(_) => // Already running.
case None =>
private def schedulePeriodicChecks(): Unit = {
var executor = scheduledExecutor.getAcquire()
if (executor == null) {
executor = Executors.newSingleThreadScheduledExecutor()
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
logInfo(
log"Starting thread for polling streaming sessions " +
log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}")
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
scheduledExecutor.get.scheduleAtFixedRate(
executor.scheduleAtFixedRate(
() => {
try periodicMaintenance()
catch {
Expand All @@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache(
sessionPollingPeriod.toMillis,
sessionPollingPeriod.toMillis,
TimeUnit.MILLISECONDS)
}
}
}

Expand Down