diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5a42299a0bf8..17014e4954f9 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,9 +18,9 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} @@ -57,13 +57,11 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] + private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() private val referenceQueue = new ReferenceQueue[AnyRef] - private val listeners = new ArrayBuffer[CleanerListener] - with SynchronizedBuffer[CleanerListener] + private val listeners = new ConcurrentLinkedQueue[CleanerListener]() private val cleaningThread = new Thread() { override def run() { keepCleaning() }} @@ -111,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { - listeners += listener + listeners.add(listener) } /** Start the cleaner. */ @@ -166,7 +164,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) } /** Keep cleaning RDD, shuffle, and broadcast state. */ @@ -179,7 +177,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { synchronized { reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get + referenceBuffer.remove(reference.get) task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) @@ -206,7 +204,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) - listeners.foreach(_.rddCleaned(rddId)) + listeners.asScala.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { case e: Exception => logError("Error cleaning RDD " + rddId, e) @@ -219,7 +217,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) blockManagerMaster.removeShuffle(shuffleId, blocking) - listeners.foreach(_.shuffleCleaned(shuffleId)) + listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { case e: Exception => logError("Error cleaning shuffle " + shuffleId, e) @@ -231,7 +229,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) - listeners.foreach(_.broadcastCleaned(broadcastId)) + listeners.asScala.foreach(_.broadcastCleaned(broadcastId)) logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) @@ -243,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) - listeners.foreach(_.accumCleaned(accId)) + listeners.asScala.foreach(_.accumCleaned(accId)) logInfo("Cleaned accumulator " + accId) } catch { case e: Exception => logError("Error cleaning accumulator " + accId, e) @@ -258,7 +256,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) - listeners.foreach(_.checkpointCleaned(rddId)) + listeners.asScala.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index eb794b6739d5..658779360b7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.deploy.client -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll @@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd /** Application Listener to collect events */ private class AppClientCollector extends AppClientListener with Logging { - val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val connectedIdList = new ConcurrentLinkedQueue[String]() @volatile var disconnectedCount: Int = 0 - val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val deadReasonList = new ConcurrentLinkedQueue[String]() + val execAddedList = new ConcurrentLinkedQueue[String]() + val execRemovedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { - connectedIdList += id + connectedIdList.add(id) } def disconnected(): Unit = { @@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } def dead(reason: String): Unit = { - deadReasonList += reason + deadReasonList.add(reason) } def executorAdded( @@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd hostPort: String, cores: Int, memory: Int): Unit = { - execAddedList += id + execAddedList.add(id) } def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { - execRemovedList += id + execRemovedList.add(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6f4eda8b47dd..22048003882d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. - * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. */ private def setupNetworkEndpoint( _env: RpcEnv, - name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => - case m => events += "receive" -> m + case m => events.add("receive" -> m) } override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress + events.add("onConnected" -> remoteAddress) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + events.add("onDisconnected" -> remoteAddress) } override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress + events.add("onNetworkError" -> remoteAddress) } }) @@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) } clientEnv.shutdown() @@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) - assert(events.map(_._1).contains("onDisconnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) } } finally { clientEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index b207d497f33c..6f7dddd4f760 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { - val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { override def onReceive(event: Int): Unit = { - buffer += event + buffer.add(event) } override def onError(e: Throwable): Unit = {} @@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts { eventLoop.start() (1 to 100).foreach(eventLoop.post) eventually(timeout(5 seconds), interval(5 millis)) { - assert((1 to 100) === buffer.toSeq) + assert((1 to 100) === buffer.asScala.toSeq) } eventLoop.stop() }