From 7b1c17822ec91039357f255ab38bc7981749ce27 Mon Sep 17 00:00:00 2001
From: Changgyoo Park
Date: Tue, 17 Sep 2024 10:11:45 +0200
Subject: [PATCH] Optimize Spark Connect service managers: replace manager-wide locks with atomics

---
 .../SparkConnectExecutionManager.scala        | 59 ++++++++----------
 .../service/SparkConnectSessionManager.scala  | 60 ++++++++-----------
 .../SparkConnectStreamingQueryCache.scala     | 22 +++----
 3 files changed, 61 insertions(+), 80 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 61b41f932199..d66964b8d34b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
 import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
@@ -66,7 +66,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
   /** Concurrent hash table containing all the current executions. */
   private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
     new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
-  private val executionsLock = new Object
 
   /** Graveyard of tombstones of executions that were abandoned and removed. */
   private val abandonedTombstones = CacheBuilder
@@ -74,13 +73,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     .maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE))
     .build[ExecuteKey, ExecuteInfo]()
 
-  /** None if there are no executions. Otherwise, the time when the last execution was removed. */
-  @GuardedBy("executionsLock")
-  private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())
+  /** The time when the last execution was removed; initialized to the manager creation time. */
+  private val lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis())
 
   /** Executor for the periodic maintenance */
-  @GuardedBy("executionsLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   /**
    * Create a new ExecuteHolder and register it with this global manager and with its session.
@@ -118,11 +116,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
 
     sessionHolder.addExecuteHolder(executeHolder)
 
-    executionsLock.synchronized {
-      if (!executions.isEmpty()) {
-        lastExecutionTimeMs = None
-      }
-    }
     logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")
 
     schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started.
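The first file's change replaces the `executionsLock`/`Option[Long]` pair with an `AtomicLong` that always holds a valid timestamp, so neither readers nor writers take a lock. A minimal sketch of that pattern, with illustrative names (`ActivityTracker` is not part of the patch):

    import java.util.concurrent.atomic.AtomicLong

    // Hypothetical stand-in for the lastExecutionTimeMs field above.
    class ActivityTracker {
      // Initialized to creation time, so readers never observe a missing value.
      private val lastRemovalTimeMs = new AtomicLong(System.currentTimeMillis())

      // A single atomic write replaces the synchronized Option update.
      def recordRemoval(): Unit = lastRemovalTimeMs.set(System.currentTimeMillis())

      // Reads are lock-free as well.
      def lastRemovalMs: Long = lastRemovalTimeMs.get()
    }

Dropping the `Option` also drops the old invariant that the value is `None` while executions exist; the field is now simply the last removal time, which `listActiveExecutions` only reports when the executions map is empty.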
@@ -151,11 +144,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     executions.remove(key)
     executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)
 
-    executionsLock.synchronized {
-      if (executions.isEmpty) {
-        lastExecutionTimeMs = Some(System.currentTimeMillis())
-      }
-    }
+    updateLastExecutionTime()
 
     logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
@@ -197,7 +186,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
    */
   def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
     if (executions.isEmpty) {
-      Left(lastExecutionTimeMs.get)
+      Left(lastExecutionTimeMs.getAcquire())
     } else {
       Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
     }
@@ -212,22 +201,23 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
   private[connect] def shutdown(): Unit = {
-    executionsLock.synchronized {
-      scheduledExecutor.foreach { executor =>
-        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
-      }
-      scheduledExecutor = None
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
+      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
 
     // note: this does not cleanly shut down the executions, but the server is shutting down.
     executions.clear()
     abandonedTombstones.invalidateAll()
 
-    executionsLock.synchronized {
-      if (lastExecutionTimeMs.isEmpty) {
-        lastExecutionTimeMs = Some(System.currentTimeMillis())
-      }
-    }
+    updateLastExecutionTime()
+  }
+
+  /**
+   * Updates the last execution time after an execution is removed; max() keeps it monotonic.
+   */
+  private def updateLastExecutionTime(): Unit = {
+    lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))
   }
 
   /**
@@ -235,16 +225,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
    * for executions that have not been closed, but are left with no RPC attached to them, and
    * removes them after a timeout.
    */
-  private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
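Two idioms here carry the correctness argument that the removed `synchronized` blocks used to provide. `shutdown()` claims the executor with `getAndSet(null)`, so at most one caller ever shuts it down, and `updateLastExecutionTime()` folds `max()` into the CAS loop of `getAndUpdate`, so a caller that read the clock earlier can never move the timestamp backwards. A runnable sketch of the monotonic update (object and method names are illustrative):

    import java.util.concurrent.atomic.AtomicLong

    object MonotonicTimestampDemo {
      private val lastMs = new AtomicLong(0L)

      // getAndUpdate retries on contention; max() makes stale writes harmless.
      def touch(nowMs: Long): Unit = lastMs.getAndUpdate(prev => prev.max(nowMs))

      def main(args: Array[String]): Unit = {
        touch(100L)
        touch(90L) // a racing caller with an older clock reading loses quietly
        assert(lastMs.get() == 100L)
      }
    }

The paired `getAcquire()` read in `listActiveExecutions` is sufficient because `getAndUpdate` publishes its result with full volatile semantics.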
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
         val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL)
         logInfo(
           log"Starting thread for cleanup of abandoned executions every " +
             log"${MDC(LogKeys.INTERVAL, interval)} ms")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try {
               val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
@@ -256,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
           interval,
           interval,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index fec01813de6e..4ca3a80bfb98 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
@@ -40,8 +40,6 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {
 
-  private val sessionsLock = new Object
-
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()
 
@@ -52,8 +50,8 @@ class SparkConnectSessionManager extends Logging {
     .build[SessionKey, SessionHolderInfo]()
 
   /** Executor for the periodic maintenance */
-  @GuardedBy("sessionsLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   private def validateSessionId(
       key: SessionKey,
@@ -75,8 +73,6 @@ class SparkConnectSessionManager extends Logging {
     val holder = getSession(
       key,
       Some(() => {
-        // Executed under sessionsState lock in getSession, to guard against concurrent removal
-        // and insertion into closedSessionsCache.
         validateSessionCreate(key)
         val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
         holder.initializeSession()
@@ -168,17 +164,14 @@ class SparkConnectSessionManager extends Logging {
 
   def closeSession(key: SessionKey): Unit = {
     val sessionHolder = removeSessionHolder(key)
-    // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by
-    // getOrCreateIsolatedSession.
+    // Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession.
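All three classes now start their maintenance thread through the same gate: read the reference with acquire semantics, speculatively create an executor, and publish it with a single `compareAndExchangeRelease`. Only the winner of the exchange schedules the task; a loser's executor has never had a task submitted, so it has not started a worker thread and can simply be discarded. A self-contained sketch of the pattern under that assumption (`MaintenanceLoop` and its members are illustrative names, not part of the patch):

    import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
    import java.util.concurrent.atomic.AtomicReference

    class MaintenanceLoop(intervalMs: Long)(task: () => Unit) {
      private val scheduled = new AtomicReference[ScheduledExecutorService]()

      def ensureStarted(): Unit = {
        var executor = scheduled.getAcquire()
        if (executor == null) {
          executor = Executors.newSingleThreadScheduledExecutor()
          // Exactly one caller sees null here and schedules the task.
          if (scheduled.compareAndExchangeRelease(null, executor) == null) {
            executor.scheduleAtFixedRate(
              () => task(), intervalMs, intervalMs, TimeUnit.MILLISECONDS)
          }
        }
      }

      def stop(): Unit = {
        val executor = scheduled.getAndSet(null)
        if (executor != null) executor.shutdown()
      }
    }

`getAcquire` and `compareAndExchangeRelease` are the VarHandle-based `AtomicReference` methods added in Java 9, so this compiles only on a Java 9+ toolchain.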
     sessionHolder.foreach(shutdownSessionHolder(_))
   }
 
   private[connect] def shutdown(): Unit = {
-    sessionsLock.synchronized {
-      scheduledExecutor.foreach { executor =>
-        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
-      }
-      scheduledExecutor = None
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
+      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
 
     // note: this does not cleanly shut down the sessions, but the server is shutting down.
@@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging {
    *
    * The checks are looking to remove sessions that expired.
    */
-  private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
         val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
         logInfo(
           log"Starting thread for cleanup of expired sessions every " +
             log"${MDC(INTERVAL, interval)} ms")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try {
               val defaultInactiveTimeoutMs =
@@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging {
           interval,
           interval,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
 
@@ -255,24 +249,18 @@ class SparkConnectSessionManager extends Logging {
 
     // .. and remove them.
     toRemove.foreach { sessionHolder =>
-      // This doesn't use closeSession to be able to do the extra last chance check under lock.
-      val removedSession = {
-        // Last chance - check expiration time and remove under lock if expired.
-        val info = sessionHolder.getSessionHolderInfo
-        if (shouldExpire(info, System.currentTimeMillis())) {
-          logInfo(
-            log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
-              log"and will be closed.")
-          removeSessionHolder(info.key)
-        } else {
-          None
+      val info = sessionHolder.getSessionHolderInfo
+      if (shouldExpire(info, System.currentTimeMillis())) {
+        logInfo(
+          log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
+            log"and will be closed.")
+        removeSessionHolder(info.key)
+        try {
+          shutdownSessionHolder(sessionHolder)
+        } catch {
+          case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
         }
       }
-      // do shutdown and cleanup outside of lock.
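With the lock gone, the expiry sweep is a plain check-then-act over the concurrent map: re-check the deadline, remove the holder, then shut it down, guarding the shutdown with a `NonFatal` catch. The sketch below shows a defensive variant of the same shape, with hypothetical types that are not the patch's code: it keys the shutdown off the atomic two-argument `remove(key, value)`, which succeeds for exactly one caller per entry, so a holder that a concurrent `closeSession()` already removed is never shut down twice.

    import java.util.concurrent.ConcurrentHashMap
    import scala.jdk.CollectionConverters._

    final case class Holder(key: String, expiresAtMs: Long) {
      def shutdown(): Unit = println(s"closing $key")
    }

    class Sweeper(store: ConcurrentHashMap[String, Holder]) {
      def sweep(nowMs: Long): Unit = {
        // The iterator is weakly consistent, so removing while iterating is safe.
        store.values().asScala.filter(_.expiresAtMs <= nowMs).foreach { holder =>
          // remove(key, value) is atomic and returns true for at most one thread.
          if (store.remove(holder.key, holder)) {
            holder.shutdown()
          }
        }
      }
    }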
-      try removedSession.foreach(shutdownSessionHolder(_))
-      catch {
-        case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
-      }
     }
     logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
   }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 03719ddd8741..8241672d5107 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
 import java.util.concurrent.Executors
 import java.util.concurrent.ScheduledExecutorService
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache(
 
   // Visible for testing.
   private[service] def shutdown(): Unit = queryCacheLock.synchronized {
-    scheduledExecutor.foreach { executor =>
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
       ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
-    scheduledExecutor = None
   }
 
   @GuardedBy("queryCacheLock")
@@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache(
   private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
   private val taggedQueriesLock = new Object
 
-  @GuardedBy("queryCacheLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   /** Schedules periodic checks if it is not already scheduled */
-  private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
         logInfo(
           log"Starting thread for polling streaming sessions " +
             log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try periodicMaintenance()
             catch {
@@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache(
           sessionPollingPeriod.toMillis,
           sessionPollingPeriod.toMillis,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
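The streaming-query cache adopts the same gate, so one property now holds across all three classes: however many threads race into `schedulePeriodicChecks()`, exactly one wins the exchange and schedules the maintenance task. A small standalone check of that claim (every name below is illustrative, and the scheduled body is omitted):

    import java.util.concurrent.{Executors, ScheduledExecutorService}
    import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

    object StartOnceCheck {
      private val ref = new AtomicReference[ScheduledExecutorService]()
      private val winners = new AtomicInteger(0)

      private def ensureStarted(): Unit = {
        var executor = ref.getAcquire()
        if (executor == null) {
          executor = Executors.newSingleThreadScheduledExecutor()
          if (ref.compareAndExchangeRelease(null, executor) == null) {
            winners.incrementAndGet() // only the publishing thread gets here
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val threads = (1 to 16).map(_ => new Thread(() => ensureStarted()))
        threads.foreach(_.start())
        threads.foreach(_.join())
        assert(winners.get() == 1, s"expected one winner, got ${winners.get()}")
        ref.getAndSet(null).shutdown()
      }
    }

Note that `shutdown()` in this file still runs under `queryCacheLock`. The atomic `getAndSet(null)` makes the lock unnecessary for the executor itself; keeping it is consistent with the `@GuardedBy("queryCacheLock")` members that remain in the class.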