Skip to content

Commit 7b1c178

Browse files
author
Changgyoo Park
committed
Optimize
1 parent dd8d127 commit 7b1c178

File tree

3 files changed

+61
-80
lines changed

3 files changed

+61
-80
lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
1919

2020
import java.util.UUID
2121
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
22-
import javax.annotation.concurrent.GuardedBy
22+
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
2323

2424
import scala.collection.mutable
2525
import scala.concurrent.duration.FiniteDuration
@@ -66,21 +66,19 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
6666
/** Concurrent hash table containing all the current executions. */
6767
private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
6868
new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
69-
private val executionsLock = new Object
7069

7170
/** Graveyard of tombstones of executions that were abandoned and removed. */
7271
private val abandonedTombstones = CacheBuilder
7372
.newBuilder()
7473
.maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE))
7574
.build[ExecuteKey, ExecuteInfo]()
7675

77-
/** None if there are no executions. Otherwise, the time when the last execution was removed. */
78-
@GuardedBy("executionsLock")
79-
private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())
76+
/** The time when the last execution was removed. */
77+
private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis())
8078

8179
/** Executor for the periodic maintenance */
82-
@GuardedBy("executionsLock")
83-
private var scheduledExecutor: Option[ScheduledExecutorService] = None
80+
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
81+
new AtomicReference[ScheduledExecutorService]()
8482

8583
/**
8684
* Create a new ExecuteHolder and register it with this global manager and with its session.
@@ -118,11 +116,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
118116

119117
sessionHolder.addExecuteHolder(executeHolder)
120118

121-
executionsLock.synchronized {
122-
if (!executions.isEmpty()) {
123-
lastExecutionTimeMs = None
124-
}
125-
}
126119
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")
127120

128121
schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started.
@@ -151,11 +144,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
151144
executions.remove(key)
152145
executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)
153146

154-
executionsLock.synchronized {
155-
if (executions.isEmpty) {
156-
lastExecutionTimeMs = Some(System.currentTimeMillis())
157-
}
158-
}
147+
updateLastExecutionTime()
159148

160149
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
161150

@@ -197,7 +186,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
197186
*/
198187
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
199188
if (executions.isEmpty) {
200-
Left(lastExecutionTimeMs.get)
189+
Left(lastExecutionTimeMs.getAcquire())
201190
} else {
202191
Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
203192
}
@@ -212,39 +201,40 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
212201
}
213202

214203
private[connect] def shutdown(): Unit = {
215-
executionsLock.synchronized {
216-
scheduledExecutor.foreach { executor =>
217-
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
218-
}
219-
scheduledExecutor = None
204+
val executor = scheduledExecutor.getAndSet(null)
205+
if (executor != null) {
206+
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
220207
}
221208

222209
// note: this does not cleanly shut down the executions, but the server is shutting down.
223210
executions.clear()
224211
abandonedTombstones.invalidateAll()
225212

226-
executionsLock.synchronized {
227-
if (lastExecutionTimeMs.isEmpty) {
228-
lastExecutionTimeMs = Some(System.currentTimeMillis())
229-
}
230-
}
213+
updateLastExecutionTime()
214+
}
215+
216+
/**
217+
* Updates the last execution time after the last execution has been removed.
218+
*/
219+
private def updateLastExecutionTime(): Unit = {
220+
lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))
231221
}
232222

233223
/**
234224
* Schedules periodic maintenance checks if it is not already scheduled. The checks are looking
235225
* for executions that have not been closed, but are left with no RPC attached to them, and
236226
* removes them after a timeout.
237227
*/
238-
private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
239-
scheduledExecutor match {
240-
case Some(_) => // Already running.
241-
case None =>
228+
private def schedulePeriodicChecks(): Unit = {
229+
var executor = scheduledExecutor.getAcquire()
230+
if (executor == null) {
231+
executor = Executors.newSingleThreadScheduledExecutor()
232+
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
242233
val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL)
243234
logInfo(
244235
log"Starting thread for cleanup of abandoned executions every " +
245236
log"${MDC(LogKeys.INTERVAL, interval)} ms")
246-
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
247-
scheduledExecutor.get.scheduleAtFixedRate(
237+
executor.scheduleAtFixedRate(
248238
() => {
249239
try {
250240
val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
@@ -256,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
256246
interval,
257247
interval,
258248
TimeUnit.MILLISECONDS)
249+
}
259250
}
260251
}
261252

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
1919

2020
import java.util.UUID
2121
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
22-
import javax.annotation.concurrent.GuardedBy
22+
import java.util.concurrent.atomic.AtomicReference
2323

2424
import scala.collection.mutable
2525
import scala.concurrent.duration.FiniteDuration
@@ -40,8 +40,6 @@ import org.apache.spark.util.ThreadUtils
4040
*/
4141
class SparkConnectSessionManager extends Logging {
4242

43-
private val sessionsLock = new Object
44-
4543
private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
4644
new ConcurrentHashMap[SessionKey, SessionHolder]()
4745

@@ -52,8 +50,8 @@ class SparkConnectSessionManager extends Logging {
5250
.build[SessionKey, SessionHolderInfo]()
5351

5452
/** Executor for the periodic maintenance */
55-
@GuardedBy("sessionsLock")
56-
private var scheduledExecutor: Option[ScheduledExecutorService] = None
53+
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
54+
new AtomicReference[ScheduledExecutorService]()
5755

5856
private def validateSessionId(
5957
key: SessionKey,
@@ -75,8 +73,6 @@ class SparkConnectSessionManager extends Logging {
7573
val holder = getSession(
7674
key,
7775
Some(() => {
78-
// Executed under sessionsState lock in getSession, to guard against concurrent removal
79-
// and insertion into closedSessionsCache.
8076
validateSessionCreate(key)
8177
val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
8278
holder.initializeSession()
@@ -168,17 +164,14 @@ class SparkConnectSessionManager extends Logging {
168164

169165
def closeSession(key: SessionKey): Unit = {
170166
val sessionHolder = removeSessionHolder(key)
171-
// Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by
172-
// getOrCreateIsolatedSession.
167+
// Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession.
173168
sessionHolder.foreach(shutdownSessionHolder(_))
174169
}
175170

176171
private[connect] def shutdown(): Unit = {
177-
sessionsLock.synchronized {
178-
scheduledExecutor.foreach { executor =>
179-
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
180-
}
181-
scheduledExecutor = None
172+
val executor = scheduledExecutor.getAndSet(null)
173+
if (executor != null) {
174+
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
182175
}
183176

184177
// note: this does not cleanly shut down the sessions, but the server is shutting down.
@@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging {
199192
*
200193
* The checks are looking to remove sessions that expired.
201194
*/
202-
private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
203-
scheduledExecutor match {
204-
case Some(_) => // Already running.
205-
case None =>
195+
private def schedulePeriodicChecks(): Unit = {
196+
var executor = scheduledExecutor.getAcquire()
197+
if (executor == null) {
198+
executor = Executors.newSingleThreadScheduledExecutor()
199+
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
206200
val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
207201
logInfo(
208202
log"Starting thread for cleanup of expired sessions every " +
209203
log"${MDC(INTERVAL, interval)} ms")
210-
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
211-
scheduledExecutor.get.scheduleAtFixedRate(
204+
executor.scheduleAtFixedRate(
212205
() => {
213206
try {
214207
val defaultInactiveTimeoutMs =
@@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging {
221214
interval,
222215
interval,
223216
TimeUnit.MILLISECONDS)
217+
}
224218
}
225219
}
226220

@@ -255,24 +249,18 @@ class SparkConnectSessionManager extends Logging {
255249

256250
// .. and remove them.
257251
toRemove.foreach { sessionHolder =>
258-
// This doesn't use closeSession to be able to do the extra last chance check under lock.
259-
val removedSession = {
260-
// Last chance - check expiration time and remove under lock if expired.
261-
val info = sessionHolder.getSessionHolderInfo
262-
if (shouldExpire(info, System.currentTimeMillis())) {
263-
logInfo(
264-
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
265-
log"and will be closed.")
266-
removeSessionHolder(info.key)
267-
} else {
268-
None
252+
val info = sessionHolder.getSessionHolderInfo
253+
if (shouldExpire(info, System.currentTimeMillis())) {
254+
logInfo(
255+
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
256+
log"and will be closed.")
257+
removeSessionHolder(info.key)
258+
try {
259+
shutdownSessionHolder(sessionHolder)
260+
} catch {
261+
case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
269262
}
270263
}
271-
// do shutdown and cleanup outside of lock.
272-
try removedSession.foreach(shutdownSessionHolder(_))
273-
catch {
274-
case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
275-
}
276264
}
277265
logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
278266
}

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
2020
import java.util.concurrent.Executors
2121
import java.util.concurrent.ScheduledExecutorService
2222
import java.util.concurrent.TimeUnit
23+
import java.util.concurrent.atomic.AtomicReference
2324
import javax.annotation.concurrent.GuardedBy
2425

2526
import scala.collection.mutable
@@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache(
185186

186187
// Visible for testing.
187188
private[service] def shutdown(): Unit = queryCacheLock.synchronized {
188-
scheduledExecutor.foreach { executor =>
189+
val executor = scheduledExecutor.getAndSet(null)
190+
if (executor != null) {
189191
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
190192
}
191-
scheduledExecutor = None
192193
}
193194

194195
@GuardedBy("queryCacheLock")
@@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache(
199200
private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
200201
private val taggedQueriesLock = new Object
201202

202-
@GuardedBy("queryCacheLock")
203-
private var scheduledExecutor: Option[ScheduledExecutorService] = None
203+
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
204+
new AtomicReference[ScheduledExecutorService]()
204205

205206
/** Schedules periodic checks if it is not already scheduled */
206-
private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
207-
scheduledExecutor match {
208-
case Some(_) => // Already running.
209-
case None =>
207+
private def schedulePeriodicChecks(): Unit = {
208+
var executor = scheduledExecutor.getAcquire()
209+
if (executor == null) {
210+
executor = Executors.newSingleThreadScheduledExecutor()
211+
if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
210212
logInfo(
211213
log"Starting thread for polling streaming sessions " +
212214
log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}")
213-
scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
214-
scheduledExecutor.get.scheduleAtFixedRate(
215+
executor.scheduleAtFixedRate(
215216
() => {
216217
try periodicMaintenance()
217218
catch {
@@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache(
221222
sessionPollingPeriod.toMillis,
222223
sessionPollingPeriod.toMillis,
223224
TimeUnit.MILLISECONDS)
225+
}
224226
}
225227
}
226228

0 commit comments

Comments
 (0)