From 1e752f1a5c5f3887df2ca20d63a9d30f1d32f9d1 Mon Sep 17 00:00:00 2001 From: Roman Pastukhov Date: Wed, 5 Feb 2014 20:11:56 +0400 Subject: [PATCH 01/37] Added unpersist method to Broadcast. --- .../scala/org/apache/spark/SparkContext.scala | 7 ++- .../apache/spark/broadcast/Broadcast.scala | 13 ++++- .../spark/broadcast/BroadcastFactory.scala | 2 +- .../spark/broadcast/HttpBroadcast.scala | 45 ++++++++++++----- .../spark/broadcast/TorrentBroadcast.scala | 43 +++++++++++++---- .../apache/spark/storage/BlockManager.scala | 12 +++++ .../apache/spark/storage/MemoryStore.scala | 31 +++++++----- .../org/apache/spark/BroadcastSuite.scala | 48 +++++++++++++++++++ 8 files changed, 163 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566472e597958..f42589c3900d0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -613,8 +613,13 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. + * + * If `registerBlocks` is true, workers will notify driver about blocks they create + * and these blocks will be dropped when `unpersist` method of the broadcast variable is called. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T, registerBlocks: Boolean = false) = { + env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks) + } /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d113d4040594d..076d98f8de991 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -53,6 +53,15 @@ import org.apache.spark._ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Removes all blocks of this broadcast from memory (and disk if removeSource is true). + * + * @param removeSource Whether to remove data from disk as well. + * Will cause errors if broadcast is accessed on workers afterwards + * (e.g. in case of RDD re-computation due to executor failure). + */ + def unpersist(removeSource: Boolean = false) + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. 
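For reference, a minimal driver-side sketch of the API these hunks introduce (sc.broadcast with registerBlocks, Broadcast.unpersist). This sketch is not part of the patch itself; it assumes an existing SparkContext named `sc` and uses illustrative data only:

  // Register the broadcast blocks with the driver so that unpersist() can later drop them.
  val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2), registerBlocks = true)
  val counts = sc.parallelize(Seq("a", "b", "a")).map(w => lookup.value.getOrElse(w, 0)).collect()

  // Remove the broadcast blocks from memory on the driver and executors.
  // Passing removeSource = true would also delete the HTTP file or torrent pieces,
  // after which tasks could no longer re-fetch the value (e.g. on recomputation).
  lookup.unpersist()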
@@ -91,8 +100,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks) def isDriver = _isDriver } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805100..e38283f244ea1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,6 +27,6 @@ import org.apache.spark.SparkConf */ trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 39ee0dbb92841..53fcc2748b4e0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -29,11 +29,20 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + + if (removeSource) { + HttpBroadcast.cleanupById(id) + } + } + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { @@ -54,7 +63,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -69,8 +78,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea class HttpBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new HttpBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { HttpBroadcast.stop() } } @@ -132,8 +141,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, 
BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -167,20 +178,30 @@ private object HttpBroadcast extends Logging { obj } + def deleteFile(fileName: String) { + try { + new File(fileName).delete() + logInfo("Deleted broadcast file '" + fileName + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e) + } + } + def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteFile(file) } } } + + def cleanupById(id: Long) { + val file = getFile(id).getAbsolutePath + files.internalMap.remove(file) + deleteFile(file) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d351dfc1f56a2..11e74675491c6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -23,16 +23,36 @@ import scala.math import scala.util.Random import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + + if (removeSource) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.removeBlock(metaId) + } else { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.dropFromMemory(metaId) + } + } + def broadcastId = BroadcastBlockId(id) + private def metaId = BroadcastHelperBlockId(broadcastId, "meta") + private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid) + private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -55,7 +75,6 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -64,7 +83,7 @@ extends Broadcast[T](id) with Logging with Serializable { // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = 
pieceBlockId(i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) @@ -94,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable { // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() @@ -109,6 +128,11 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { + if (arrayOfBlocks != null) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 @@ -117,7 +141,6 @@ extends Broadcast[T](id) with Logging with Serializable { def receiveBroadcast(variableID: Long): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -140,9 +163,9 @@ extends Broadcast[T](id) with Logging with Serializable { } // Receive actual blocks - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + val recvOrder = new Random().shuffle(pieceIds) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = pieceBlockId(pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => @@ -243,8 +266,8 @@ class TorrentBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new TorrentBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { TorrentBroadcast.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed53558566edf..f8c121615567f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -196,6 +196,11 @@ private[spark] class BlockManager( } } + /** + * For testing. Returns number of blocks BlockManager knows about that are in memory. + */ + def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_)) + /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -720,6 +725,13 @@ private[spark] class BlockManager( } /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId) + } + + /** * Remove all blocks belonging to the given RDD. * @return The number of blocks removed. 
*/ diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index eb5a18521683e..4e47a06c1fed2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -182,6 +182,24 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + } + } + /** * Tries to free up a given amount of space to store a particular block, but can fail and return * false if either the block is bigger than our memory or it would require replacing another @@ -227,18 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } + dropFromMemory(blockId) } return true } else { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a657753144b24 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -18,6 +18,11 @@ package org.apache.spark import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Millis, Span} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.scalatest.matchers.ShouldMatchers._ class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -82,4 +87,47 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } + def blocksExist(sc: SparkContext, numSlaves: Int) = { + val rdd = sc.parallelize(1 to numSlaves, numSlaves) + val workerBlocks = rdd.mapPartitions(_ => { + val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory() + Seq(blocks).iterator + }) + val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory() + + totalKnown > 0 + } + + def testUnpersist(bcFactory: String, removeSource: Boolean) { + test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) { + val numSlaves = 2 + System.setProperty("spark.broadcast.factory", bcFactory) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + + assert(!blocksExist(sc, numSlaves)) + + val listBroadcast = sc.broadcast(list, true) + val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + + assert(blocksExist(sc, numSlaves)) + + listBroadcast.unpersist(removeSource) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + blocksExist(sc, numSlaves) should be (false) + } + + if (!removeSource) { + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + } + + for (removeSource <- Seq(true, false)) { + testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource) + testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource) + } } From 80dd9778d2e7338bc93bc7de95ecc6776b0d9e8b Mon Sep 17 00:00:00 2001 From: Roman Pastukhov Date: Fri, 7 Feb 2014 02:53:29 +0400 Subject: [PATCH 02/37] Fix for Broadcast unpersist patch. Updated comment in MemoryStore.dropFromMemory Keep TorrentBroadcast piece blocks until unpersist is called --- .../spark/broadcast/HttpBroadcast.scala | 10 +++- .../spark/broadcast/TorrentBroadcast.scala | 57 ++++++++++++++----- .../apache/spark/storage/MemoryStore.scala | 6 +- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 53fcc2748b4e0..7f056b8feae27 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -35,11 +35,15 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(blockId) - SparkEnv.get.blockManager.removeBlock(blockId) + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + } if (removeSource) { - HttpBroadcast.cleanupById(id) + HttpBroadcast.synchronized { + HttpBroadcast.cleanupById(id) + } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 11e74675491c6..e6a8ae199e723 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -33,19 +33,55 @@ extends Broadcast[T](id) with Logging with Serializable { def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(broadcastId) - SparkEnv.get.blockManager.removeBlock(broadcastId) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + } + + if (!removeSource) { + //We can't tell BlockManager master to remove blocks from all nodes except driver, + //so we need to save them here in order to store them on disk later. + //This may be inefficient if blocks were already dropped to disk, + //but since unpersist is supposed to be called right after working with + //a broadcast this should not happen (and getting them from memory is cheap). 
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + + for (pid <- 0 until totalBlocks) { + val pieceId = pieceBlockId(pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + } + + for (pid <- 0 until totalBlocks) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid)) + } + } if (removeSource) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.removeBlock(metaId) } - SparkEnv.get.blockManager.removeBlock(metaId) } else { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.dropFromMemory(metaId) } - SparkEnv.get.blockManager.dropFromMemory(metaId) + + for (i <- 0 until totalBlocks) { + val pieceId = pieceBlockId(i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true) + } + } + arrayOfBlocks = null } } @@ -128,11 +164,6 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { - if (arrayOfBlocks != null) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) - } - } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4e47a06c1fed2..5dff0e95b31ba 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -187,9 +187,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ def dropFromMemory(blockId: BlockId) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. + // This should never be null if called from ensureFreeSpace as only one + // thread should be dropping blocks and removing entries. + // However the check is required in other cases. if (entry != null) { val data = if (entry.deserialized) { Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) From e427a9eeb8d6b5def3a5ff1b766458588d8b05a9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 13 Feb 2014 19:14:31 -0800 Subject: [PATCH 03/37] Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap to BoundedHashMaps and TimeStampedWeakValueHashMap for the necessary hashmap behavior. 
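Before the file-by-file changes, a hedged sketch of the lifecycle this commit is aiming for, using only the pieces added below (RDD.finalize/cleanup, ShuffleDependency.finalize, ContextCleaner). The driver program and names here are illustrative, not part of the patch:

  // A cached, shuffled RDD that falls out of scope should eventually be cleaned up.
  def runOnce(sc: SparkContext): Long = {
    val pairs = sc.parallelize(1 to 1000).map(i => (i % 10, i))
    val counts = pairs.reduceByKey(_ + _).persist()  // registers a shuffle and caches blocks
    counts.count()
  }  // `pairs` and `counts` become unreachable when this method returns

  runOnce(sc)
  System.gc(); System.runFinalization()  // best effort: finalize() on the unreachable RDD / ShuffleDependency enqueues cleaning tasks
  // The ContextCleaner daemon thread then removes the cached blocks, unregisters the shuffle
  // with MapOutputTrackerMaster, and deletes its files, as exercised in ContextCleanerSuite below.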
--- .../org/apache/spark/ContextCleaner.scala | 126 +++++++++++ .../scala/org/apache/spark/Dependency.scala | 6 + .../org/apache/spark/MapOutputTracker.scala | 64 ++++-- .../scala/org/apache/spark/SparkContext.scala | 12 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 10 + .../apache/spark/scheduler/ResultTask.scala | 12 +- .../spark/scheduler/ShuffleMapTask.scala | 17 +- .../spark/storage/BlockManagerMaster.scala | 15 ++ .../storage/BlockManagerMasterActor.scala | 12 + .../spark/storage/BlockManagerMessages.scala | 3 + .../storage/BlockManagerSlaveActor.scala | 3 + .../spark/storage/DiskBlockManager.scala | 5 + .../spark/storage/ShuffleBlockManager.scala | 33 ++- .../apache/spark/util/BoundedHashMap.scala | 45 ++++ .../apache/spark/util/MetadataCleaner.scala | 4 +- .../util/TimeStampedWeakValueHashMap.scala | 84 +++++++ .../spark/util/WrappedJavaHashMap.scala | 126 +++++++++++ .../apache/spark/ContextCleanerSuite.scala | 210 ++++++++++++++++++ .../apache/spark/MapOutputTrackerSuite.scala | 25 ++- .../spark/storage/DiskBlockManagerSuite.scala | 3 +- .../spark/util/WrappedJavaHashMapSuite.scala | 189 ++++++++++++++++ 22 files changed, 946 insertions(+), 60 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ContextCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala create mode 100644 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala new file mode 100644 index 0000000000000..1cc4271f8cf33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} + +import org.apache.spark.rdd.RDD + +/** Listener class used for testing when any item has been cleaned by the Cleaner class */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) +} + +/** + * Cleans RDDs and shuffle data. This should be instantiated only on the driver. 
+ */ +private[spark] class ContextCleaner(env: SparkEnv) extends Logging { + + /** Classes to represent cleaning tasks */ + private sealed trait CleaningTask + private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask + private case class CleanShuffle(id: Int) extends CleaningTask + // TODO: add CleanBroadcast + + private val QUEUE_CAPACITY = 1000 + private val queue = new ArrayBlockingQueue[CleaningTask](QUEUE_CAPACITY) + + protected val listeners = new ArrayBuffer[CleanerListener] + with SynchronizedBuffer[CleanerListener] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + private var stopped = false + + /** Start the cleaner */ + def start() { + cleaningThread.setDaemon(true) + cleaningThread.start() + } + + /** Stop the cleaner */ + def stop() { + synchronized { stopped = true } + cleaningThread.interrupt() + } + + /** Clean all data and metadata related to a RDD, including shuffle files and metadata */ + def cleanRDD(rdd: RDD[_]) { + enqueue(CleanRDD(rdd.sparkContext, rdd.id)) + logDebug("Enqueued RDD " + rdd + " for cleaning up") + } + + def cleanShuffle(shuffleId: Int) { + enqueue(CleanShuffle(shuffleId)) + logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") + } + + def attachListener(listener: CleanerListener) { + listeners += listener + } + /** Enqueue a cleaning task */ + private def enqueue(task: CleaningTask) { + queue.put(task) + } + + /** Keep cleaning RDDs and shuffle data */ + private def keepCleaning() { + try { + while (!isStopped) { + val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS)) + if (taskOpt.isDefined) { + logDebug("Got cleaning task " + taskOpt.get) + taskOpt.get match { + case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId) + case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId) + } + } + } + } catch { + case ie: java.lang.InterruptedException => + if (!isStopped) logWarning("Cleaning thread interrupted") + } + } + + /** Perform RDD cleaning */ + private def doCleanRDD(sc: SparkContext, rddId: Int) { + logDebug("Cleaning rdd "+ rddId) + sc.env.blockManager.master.removeRdd(rddId, false) + sc.persistentRdds.remove(rddId) + listeners.foreach(_.rddCleaned(rddId)) + logInfo("Cleaned rdd "+ rddId) + } + + /** Perform shuffle cleaning */ + private def doCleanShuffle(shuffleId: Int) { + logDebug("Cleaning shuffle "+ shuffleId) + mapOutputTrackerMaster.unregisterShuffle(shuffleId) + blockManager.master.removeShuffle(shuffleId) + listeners.foreach(_.shuffleCleaned(shuffleId)) + logInfo("Cleaned shuffle " + shuffleId) + } + + private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + private def blockManager = env.blockManager + + private def isStopped = synchronized { stopped } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cc30105940d1a..dba0604ab4866 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -52,6 +52,12 @@ class ShuffleDependency[K, V]( extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + + override def finalize() { + if (rdd != null) { + rdd.sparkContext.cleaner.cleanShuffle(shuffleId) + } + } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 30d182b008930..bf291bf71bb61 100644 --- 
a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -17,19 +17,19 @@ package org.apache.spark +import scala.Some +import scala.collection.mutable.{HashSet, Map} +import scala.concurrent.Await + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashSet -import scala.concurrent.Await -import scala.concurrent.duration._ - import akka.actor._ import akka.pattern.ask import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) @@ -51,23 +51,21 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { +private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + /** This HashMap needs to have different storage behavior for driver and worker */ + protected val mapStatuses: Map[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. protected var epoch: Long = 0 protected val epochLock = new java.lang.Object - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) - // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. private def askTracker(message: Any): Any = { @@ -138,8 +136,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } - } - else { + } else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } @@ -151,13 +148,12 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } protected def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) + mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime) } def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() - metadataCleaner.cancel() trackerActor = null } @@ -182,15 +178,42 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + + /** + * Bounded HashMap for storing serialized statuses in the worker. This allows + * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be + * automatically repopulated by fetching them again from the driver. 
+ */ + protected val MAX_MAP_STATUSES = 100 + protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true) +} + + private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { // Cache a serialized version of the output statuses for each shuffle to send them out faster private var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + /** + * Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped + * only by explicit deregistering or by ttl-based cleaning (if set). Other than these two + * scenarios, nothing should be dropped from this HashMap. + */ + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + + /** + * Bounded HashMap for storing serialized statuses in the master. This allows + * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be + * automatically repopulated by serializing the lost statuses again . + */ + protected val MAX_SERIALIZED_STATUSES = 100 + private val cachedSerializedStatuses = + new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -224,6 +247,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -260,9 +287,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) bytes } - protected override def cleanup(cleanupTime: Long) { - super.cleanup(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) + def contains(shuffleId: Int): Boolean = { + mapStatuses.contains(shuffleId) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566472e597958..9fab2a7e0c707 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,8 +48,10 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType, - ClosureCleaner} +import org.apache.spark.util._ +import scala.Some +import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage.StorageStatus /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark @@ -150,7 +152,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] + private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) @@ -202,6 +204,9 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() + private[spark] val cleaner = new ContextCleaner(env) + cleaner.start() + ui.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -784,6 +789,7 @@ class SparkContext( dagScheduler = null if (dagSchedulerCopy != null) { metadataCleaner.cancel() + cleaner.stop() dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ed788560e79f1..23dbe18fd2576 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -181,7 +181,7 @@ object SparkEnv extends Logging { val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { - new MapOutputTracker(conf) + new MapOutputTrackerWorker(conf) } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8010bb68e31dd..37168d7cd5969 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1012,6 +1012,13 @@ abstract class RDD[T: ClassTag]( checkpointData.flatMap(_.getCheckpointFile) } + def cleanup() { + sc.cleaner.cleanRDD(this) + dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId) + .foreach(sc.cleaner.cleanShuffle) + } + // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1091,4 +1098,7 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } + override def finalize() { + cleanup() + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 28f3ba53b8425..671faf42a9278 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -21,20 +21,16 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData -import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.util.BoundedHashMap private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. 
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) + val MAX_CACHE_SIZE = 100 + val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index a37ead563271a..df3a7b9ee37ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,29 +17,24 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner} -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData - +import org.apache.spark.util.BoundedHashMap private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) + val MAX_CACHE_SIZE = 100 + val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c54e4f2664753..55d8349ea9d2c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -82,6 +82,14 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } + /** + * Check if block manager master has a block. Note that this can be used to check for only + * those blocks that are expected to be reported to block manager master. + */ + def contains(blockId: BlockId) = { + !getLocations(blockId).isEmpty + } + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) @@ -113,6 +121,13 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo } } + /** + * Remove all blocks belonging to the given shuffle. 
+ */ + def removeShuffle(shuffleId: Int) { + askDriverWithReply(RemoveShuffle(shuffleId)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 2c1a4e2f5d3a1..8b972672c8117 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -95,6 +95,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act case RemoveRdd(rddId) => sender ! removeRdd(rddId) + case RemoveShuffle(shuffleId) => + removeShuffle(shuffleId) + sender ! true + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -143,6 +147,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act }.toSeq) } + private def removeShuffle(shuffleId: Int) { + // Nothing to do in the BlockManagerMasterActor data structures + val removeMsg = RemoveShuffle(shuffleId) + blockManagerInfo.values.map { bm => + bm.slaveActor ! removeMsg + } + } + private def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 45f51da288548..98a3b68748ada 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -35,6 +35,9 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific shuffle. + case class RemoveShuffle(shuffleId: Int) + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 3a65e55733834..eeeee07ebb722 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -36,5 +36,8 @@ class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { case RemoveRdd(rddId) => val numBlocksRemoved = blockManager.removeRdd(rddId) sender ! numBlocksRemoved + + case RemoveShuffle(shuffleId) => + blockManager.shuffleBlockManager.removeShuffle(shuffleId) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3e1c38744d78..cdee285a1cbd4 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,6 +90,11 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Check if disk block manager has a block */ + def contains(blockId: BlockId): Boolean = { + getBlockLocation(blockId).file.exists() + } + /** Produces a unique block id and File suitable for intermediate results. 
*/ def createTempBlock(): (TempBlockId, File) = { var blockId = new TempBlockId(UUID.randomUUID()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb134cc..ed03f189fb4ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,23 +169,32 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } + /** Remove all the blocks / files related to a particular shuffle */ + def removeShuffle(shuffleId: ShuffleId) { + shuffleStates.get(shuffleId) match { + case Some(state) => + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + logInfo("Deleted all files for shuffle " + shuffleId) + case None => + logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + } + } + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } - } - }) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId)) } } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala new file mode 100644 index 0000000000000..0095b8a38d7b6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala @@ -0,0 +1,45 @@ +package org.apache.spark.util + +import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} + +import java.util.{Collections, LinkedHashMap} +import java.util.Map.{Entry => JMapEntry} +import scala.reflect.ClassTag + +/** + * A map that bounds the number of key-value pairs present in it. It can be configured to + * drop least recently inserted or used pair. It exposes a scala.collection.mutable.Map interface + * to allow it to be a drop-in replacement of Scala HashMaps. Internally, a Java LinkedHashMap is + * used to get insert-order or access-order behavior. Note that the LinkedHashMap is not + * thread-safe and hence, it is wrapped in a Collections.synchronizedMap. + * However, getting the Java HashMap's iterator and using it can still lead to + * ConcurrentModificationExceptions. Hence, the iterator() function is overridden to copy the + * all pairs into an ArrayBuffer and then return the iterator to the ArrayBuffer. Also, + * the class apply the trait SynchronizedMap which ensures that all calls to the Scala Map API + * are synchronized. This together ensures that ConcurrentModificationException is never thrown. 
+ * @param bound max number of key-value pairs + * @param useLRU true = least recently used/accessed will be dropped when bound is reached, + * false = earliest inserted will be dropped + */ +private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) + extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { + + protected[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( + bound / 8, (0.75).toFloat, useLRU) { + override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { + size() > bound + } + }) + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new BoundedHashMap[K1, V1](bound, useLRU) + } + + /** + * Overriding iterator to make sure that the internal Java HashMap's iterator + * is not concurrently modified. + */ + override def iterator: Iterator[(A, B)] = { + (new ArrayBuffer[(A, B)] ++= super.iterator).iterator + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index b0febe906ade3..1953e4cd2b59e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -62,8 +62,8 @@ private[spark] class MetadataCleaner( private[spark] object MetadataCleanerType extends Enumeration { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, - SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER, + SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value type MetadataCleanerType = Value diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala new file mode 100644 index 0000000000000..43848def0ffe6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -0,0 +1,84 @@ +package org.apache.spark.util + +import scala.collection.{JavaConversions, immutable} + +import java.util +import java.lang.ref.WeakReference +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.Logging + +private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { + def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) +} + + +private[spark] class TimeStampedWeakValueHashMap[A, B] + extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + + protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { + new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() + } + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new TimeStampedWeakValueHashMap[K1, V1]() + } + + override def get(key: A): Option[B] = { + Option(internalJavaMap.get(key)) match { + case Some(weakValue) => + val value = weakValue.weakValue.get + if (value == null) cleanupKey(key) + Option(value) + case None => + None + } + } + + @inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = { + new TimeStampedWeakValue(currentTime, v) + } + + @inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = { + iv.weakValue.get + } + + override def iterator: Iterator[(A, B)] = { + val jIterator = internalJavaMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).flatMap(kv => 
{ + val key = kv.getKey + val value = kv.getValue.weakValue.get + if (value == null) { + cleanupKey(key) + Seq.empty + } else { + Seq((key, value)) + } + }) + } + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime`, + * calling the supplied function on each such entry before removing. + */ + def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) { + val iterator = internalJavaMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue.timestamp < threshTime) { + val value = entry.getValue.weakValue.get + if (f != null && value != null) { + f(entry.getKey, value) + } + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def cleanupKey(key: A) { + // TODO: Consider cleaning up keys to empty weak ref values automatically in future. + } + + private def currentTime = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala new file mode 100644 index 0000000000000..e7c66e494678b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala @@ -0,0 +1,126 @@ +package org.apache.spark.util + +import scala.collection.mutable.Map +import java.util.{Map => JMap} +import java.util.Map.{Entry => JMapEntry} +import scala.collection.{immutable, JavaConversions} +import scala.reflect.ClassTag + +/** + * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the + * exposed key-value type is different from the internal type. This allows Scala HashMaps to be + * hot replaceable with these Java HashMaps. + * + * While Java <-> Scala conversion methods exists, its hard to understand the performance + * implications and thread safety of the Scala wrapper. This class allows you to convert + * between types and applying the necessary overridden methods to take care of performance. + * + * Note that the threading behavior of an implementation of WrappedJavaHashMap is tied to that of + * the internal Java HashMap used in the implementation. Each implementation must use + * necessary traits (e.g, scala.collection.mutable.SynchronizedMap), etc. to achieve the + * desired thread safety. + * + * @tparam K External key type + * @tparam V External value type + * @tparam IK Internal key type + * @tparam IV Internal value type + */ +private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] { + + /* Methods that must be defined. */ + + /** Internal Java HashMap that is being wrapped. */ + protected[util] val internalJavaMap: JMap[IK, IV] + + /** Method to get a new instance of the internal Java HashMap. */ + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] + + /* + Methods that convert between internal and external types. These implementations + optimistically assume that the internal types are same as external types. These must + be overridden if the internal and external types are different. Otherwise there will be + runtime exceptions. 
+ */ + + @inline protected def externalKeyToInternalKey(k: K): IK = { + k.asInstanceOf[IK] // works only if K is the same as or a subclass of IK + } + + @inline protected def externalValueToInternalValue(v: V): IV = { + v.asInstanceOf[IV] // works only if V is the same as or a subclass of IV + } + + @inline protected def internalKeyToExternalKey(ik: IK): K = { + ik.asInstanceOf[K] + } + + @inline protected def internalValueToExternalValue(iv: IV): V = { + iv.asInstanceOf[V] + } + + @inline protected def internalPairToExternalPair(ip: JMapEntry[IK, IV]): (K, V) = { + (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue)) + } + + /* Implicit functions to convert the types. */ + + @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) + + @inline implicit private def convExtValueToIntValue(v: V) = externalValueToInternalValue(v) + + @inline implicit private def convIntKeyToExtKey(ia: IK) = internalKeyToExternalKey(ia) + + @inline implicit private def convIntValueToExtValue(ib: IV) = internalValueToExternalValue(ib) + + @inline implicit private def convIntPairToExtPair(ip: JMapEntry[IK, IV]) = { + internalPairToExternalPair(ip) + } + + def get(key: K): Option[V] = { + Option(internalJavaMap.get(key)) + } + + def iterator: Iterator[(K, V)] = { + val jIterator = internalJavaMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) + } + + def +=(kv: (K, V)): this.type = { + internalJavaMap.put(kv._1, kv._2) + this + } + + def -=(key: K): this.type = { + internalJavaMap.remove(key) + this + } + + override def + [V1 >: V](kv: (K, V1)): Map[K, V1] = { + val newMap = newInstance[K, V1]() + newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) + newMap += kv + newMap + } + + override def - (key: K): Map[K, V] = { + val newMap = newInstance[K, V]() + newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) + newMap -= key + } + + override def foreach[U](f: ((K, V)) => U) { + // Materialize a single iterator; calling `iterator` on each loop test would restart the traversal. + val iter = iterator + while (iter.hasNext) { + f(iter.next()) + } + } + + override def empty: Map[K, V] = newInstance[K, V]() + + override def size: Int = internalJavaMap.size + + override def filter(p: ((K, V)) => Boolean): Map[K, V] = { + newInstance[K, V]() ++= iterator.filter(p) + } + + def toMap: immutable.Map[K, V] = iterator.toMap +} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala new file mode 100644 index 0000000000000..2ec314aa632f3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -0,0 +1,210 @@ +package org.apache.spark + +import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} +import org.apache.spark.rdd.RDD +import scala.util.Random +import java.lang.ref.WeakReference + +class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + implicit val defaultTimeout = timeout(10000 millis) + + before { + sc = new SparkContext("local[2]", "CleanerSuite") + } + + test("cleanup RDD") { + val rdd = newRDD.persist() + rdd.count() + val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + cleaner.cleanRDD(rdd) + tester.assertCleanup + } + + test("cleanup shuffle")
{ + val rdd = newShuffleRDD + rdd.count() + val tester = new CleanerTester(sc, shuffleIds = Seq(0)) + cleaner.cleanShuffle(0) + tester.assertCleanup + } + + test("automatically cleanup RDD") { + var rdd = newRDD.persist() + rdd.count() + + // test that GC does not cause RDD cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + doGC() + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes RDD cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null // make RDD out of scope + doGC() + postGCTester.assertCleanup + } + + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD + rdd.count() + + // test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + doGC() + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope + doGC() + postGCTester.assertCleanup + } + + test("automatically cleanup RDD + shuffle") { + + def randomRDD: RDD[_] = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + val buffer = new ArrayBuffer[RDD[_]] + for (i <- 1 to 1000) { + buffer += randomRDD + } + + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) + buffer.clear() + doGC() + postGCTester.assertCleanup + } + + def newRDD = sc.makeRDD(1 to 10) + + def newPairRDD = newRDD.map(_ -> 1) + + def newShuffleRDD = newPairRDD.reduceByKey(_ + _) + + def doGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + } + + def cleaner = sc.cleaner +} + + +/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. 
*/ +class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) + extends Logging { + + val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + + val cleanerListener = new CleanerListener { + def rddCleaned(rddId: Int): Unit = { + toBeCleanedRDDIds -= rddId + logInfo("RDD "+ rddId + " cleaned") + } + + def shuffleCleaned(shuffleId: Int): Unit = { + toBeCleanedShuffleIds -= shuffleId + logInfo("Shuffle " + shuffleId + " cleaned") + } + } + + logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) + preCleanupValidate() + sc.cleaner.attachListener(cleanerListener) + + def assertCleanup(implicit waitTimeout: Eventually.Timeout) { + try { + eventually(waitTimeout, interval(10 millis)) { + assert(isAllCleanedUp) + } + Thread.sleep(100) // to allow async cleanup actions to be completed + postCleanupValidate() + } finally { + logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString) + } + } + + private def preCleanupValidate() { + assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") + + // Verify the RDDs have been persisted and blocks are present + assert(rddIds.forall(sc.persistentRdds.contains), + "One or more RDDs have not been persisted, cannot start cleaner test") + assert(rddIds.forall(rddId => blockManager.master.contains(rddBlockId(rddId))), + "One or more RDDs' blocks cannot be found in block manager, cannot start cleaner test") + + // Verify the shuffle ids are registered and blocks are present + assert(shuffleIds.forall(mapOutputTrackerMaster.contains), + "One or more shuffles have not been registered cannot start cleaner test") + assert(shuffleIds.forall(shuffleId => diskBlockManager.contains(shuffleBlockId(shuffleId))), + "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") + } + + private def postCleanupValidate() { + // Verify all the RDDs have been persisted + assert(rddIds.forall(!sc.persistentRdds.contains(_))) + assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) + + // Verify all the shuffle have been deregistered and cleaned up + assert(shuffleIds.forall(!mapOutputTrackerMaster.contains(_))) + assert(shuffleIds.forall(shuffleId => !diskBlockManager.contains(shuffleBlockId(shuffleId)))) + } + + private def uncleanedResourcesToString = { + s""" + |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} + |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} + """.stripMargin + } + + private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty + + private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + + private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) + + private def blockManager = sc.env.blockManager + + private def diskBlockManager = blockManager.diskBlockManager + + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 930c2523caf8c..7675a47552ba4 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -54,11 +54,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register 
and fetch") { + test("master register shuffle and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) tracker.registerShuffle(10, 2) + assert(tracker.contains(10)) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) @@ -73,7 +74,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and unregister and fetch") { + test("master register and unregister shuffle") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + assert(tracker.contains(10)) + assert(tracker.getServerStatuses(10, 0).nonEmpty) + tracker.unregisterShuffle(10) + assert(!tracker.contains(10)) + assert(tracker.getServerStatuses(10, 0).isEmpty) + } + + test("master register shuffle and unregister mapoutput and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) @@ -105,7 +124,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 829f389460f3b..d3d22bc1d6c0e 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -58,8 +58,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assertSegmentEquals(blockId, blockId.name, 0, 10) - + assert(diskBlockManager.contains(blockId)) newFile.delete() + assert(!diskBlockManager.contains(blockId)) } test("block appending") { diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala new file mode 100644 index 0000000000000..9fc4681b524e9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -0,0 +1,189 @@ +package org.apache.spark.util + +import scala.collection.mutable.{HashMap, Map} + +import java.util + +import org.scalatest.FunSuite +import scala.util.Random +import java.lang.ref.WeakReference + +class WrappedJavaHashMapSuite 
extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new HashMap[String, String]()) + + // Test a simple WrappedJavaHashMap + testMap(new TestMap[String, String]()) + + // Test BoundedHashMap + testMap(new BoundedHashMap[String, String](100, true)) + + testMapThreadSafety(new BoundedHashMap[String, String](100, true)) + + // Test TimeStampedHashMap + testMap(new TimeStampedHashMap[String, String]) + + testMapThreadSafety(new TimeStampedHashMap[String, String]) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis() + assert(map.internalMap.get("k1")._2 < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis() + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.internalMap.get("k1")._2 < threshTime1) + assert(map1.internalMap.get("k2")._2 >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + // Test TimeStampedHashMap + testMap(new TimeStampedWeakValueHashMap[String, String]) + + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]) + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String]() + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis() + assert(map.internalJavaMap.get("k1").timestamp < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + } + + + test("TimeStampedWeakValueHashMap - get not returning null when weak reference is cleared") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + + map("k1") = strongRef + assert(map("k1") === strongRef) + + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. 
+ while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.internalJavaMap.get("k1").weakValue.get == null) + assert(map.get("k1") === None) + } + + def testMap(hashMapConstructor: => Map[String, String]) { + def newMap() = hashMapConstructor + + val name = newMap().getClass.getSimpleName + + test(name + " - basic test") { + val testMap1 = newMap() + + // put and get + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[Exception] { + testMap1("k2") // Map.apply() causes exception + } + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + val testMap2 = newMap() + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + testMap2("k1") = "v1" + + // multi remove + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // new instance + } + } + + def testMapThreadSafety(hashMapConstructor: => Map[String, String]) { + def newMap() = hashMapConstructor + + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 100).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble.toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t : Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} + +class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { + protected[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new TestMap[K1, V1] + } +} \ No newline at end of file From 8512612036011b5cf688a1643d0d46f144a0f15e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Feb 2014 00:01:04 -0800 Subject: [PATCH 04/37] Changed TimeStampedHashMap to use WrappedJavaHashMap. --- .../apache/spark/storage/BlockManager.scala | 3 +- .../apache/spark/util/BoundedHashMap.scala | 44 +++++-- .../spark/util/TimeStampedHashMap.scala | 113 ++++++------------ .../util/TimeStampedWeakValueHashMap.scala | 30 ++++- .../spark/util/WrappedJavaHashMap.scala | 32 ++++- .../spark/util/WrappedJavaHashMapSuite.scala | 36 ++++-- 6 files changed, 154 insertions(+), 104 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed53558566edf..a84c78fd0c027 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -225,6 +225,7 @@ private[spark] class BlockManager( * the slave needs to re-register. 
*/ private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { + logInfo("Reporting " + blockId) val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -770,7 +771,7 @@ private[spark] class BlockManager( val iterator = blockInfo.internalMap.entrySet().iterator() while (iterator.hasNext) { val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala index 0095b8a38d7b6..c4f7df1ee0a7b 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} @@ -7,16 +24,20 @@ import java.util.Map.{Entry => JMapEntry} import scala.reflect.ClassTag /** - * A map that bounds the number of key-value pairs present in it. It can be configured to - * drop least recently inserted or used pair. It exposes a scala.collection.mutable.Map interface - * to allow it to be a drop-in replacement of Scala HashMaps. Internally, a Java LinkedHashMap is - * used to get insert-order or access-order behavior. Note that the LinkedHashMap is not - * thread-safe and hence, it is wrapped in a Collections.synchronizedMap. - * However, getting the Java HashMap's iterator and using it can still lead to - * ConcurrentModificationExceptions. Hence, the iterator() function is overridden to copy the - * all pairs into an ArrayBuffer and then return the iterator to the ArrayBuffer. Also, - * the class apply the trait SynchronizedMap which ensures that all calls to the Scala Map API - * are synchronized. This together ensures that ConcurrentModificationException is never thrown. + * A map that upper bounds the number of key-value pairs present in it. It can be configured to + * drop the least recently user pair or the earliest inserted pair. It exposes a + * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala + * HashMaps. + * + * Internally, a Java LinkedHashMap is used to get insert-order or access-order behavior. + * Note that the LinkedHashMap is not thread-safe and hence, it is wrapped in a + * Collections.synchronizedMap. However, getting the Java HashMap's iterator and + * using it can still lead to ConcurrentModificationExceptions. 
Hence, the iterator() + * function is overridden to copy the all pairs into an ArrayBuffer and then return the + * iterator to the ArrayBuffer. Also, the class apply the trait SynchronizedMap which + * ensures that all calls to the Scala Map API are synchronized. This together ensures + * that ConcurrentModificationException is never thrown. + * * @param bound max number of key-value pairs * @param useLRU true = least recently used/accessed will be dropped when bound is reached, * false = earliest inserted will be dropped @@ -37,7 +58,8 @@ private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) /** * Overriding iterator to make sure that the internal Java HashMap's iterator - * is not concurrently modified. + * is not concurrently modified. This can be a performance issue and this should be overridden + * if it is known that this map will not be used in a multi-threaded environment. */ override def iterator: Iterator[(A, B)] = { (new ArrayBuffer[(A, B)] ++= super.iterator).iterator diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 8e07a0f29addf..4a4b7d6837bca 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -18,108 +18,66 @@ package org.apache.spark.util import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions -import scala.collection.mutable.Map -import scala.collection.immutable -import org.apache.spark.scheduler.MapStatus + import org.apache.spark.Logging +private[util] case class TimeStampedValue[T](timestamp: Long, value: T) + /** - * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * timestamp along with each key-value pair. If specified, the timestamp of each pair can be - * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular - * threshold time can then be removed using the clearOldValues method. This is intended to - * be a drop-in replacement of scala.collection.mutable.HashMap. + * A map that stores the timestamp of when a key was inserted along with the value. If specified, + * the timestamp of each pair can be updated every time it is accessed. + * Key-value pairs whose timestamps are older than a particular + * threshold time can then be removed using the clearOldValues method. It exposes a + * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala + * HashMaps. + * + * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. 
+ * * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be * updated when it is accessed */ -class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends Map[A, B]() with Logging { - val internalMap = new ConcurrentHashMap[A, (B, Long)]() - - def get(key: A): Option[B] = { - val value = internalMap.get(key) - if (value != null && updateTimeStampOnGet) { - internalMap.replace(key, value, (value._1, currentTime)) - } - Option(value).map(_._1) - } - - def iterator: Iterator[(A, B)] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) - } - - override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { - val newMap = new TimeStampedHashMap[A, B1] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.put(kv._1, (kv._2, currentTime)) - newMap - } - - override def - (key: A): Map[A, B] = { - val newMap = new TimeStampedHashMap[A, B] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.remove(key) - newMap - } +private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging { - override def += (kv: (A, B)): this.type = { - internalMap.put(kv._1, (kv._2, currentTime)) - this - } + protected[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() - // Should we return previous value directly or as Option ? - def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, (value, currentTime)) - if (prev != null) Some(prev._1) else None + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new TimeStampedHashMap[K1, V1]() } + def internalMap = internalJavaMap - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def update(key: A, value: B) { - this += ((key, value)) + override def get(key: A): Option[B] = { + val timeStampedValue = internalMap.get(key) + if (updateTimeStampOnGet && timeStampedValue != null) { + internalJavaMap.replace(key, timeStampedValue, TimeStampedValue(currentTime, timeStampedValue.value)) + } + Option(timeStampedValue).map(_.value) } - - override def apply(key: A): B = { - val value = internalMap.get(key) - if (value == null) throw new NoSuchElementException() - value._1 + @inline override protected def externalValueToInternalValue(v: B): TimeStampedValue[B] = { + new TimeStampedValue(currentTime, v) } - override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + @inline override protected def internalValueToExternalValue(iv: TimeStampedValue[B]): B = { + iv.value } - override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val kv = (entry.getKey, entry.getValue._1) - f(kv) - } + /** Atomically put if a key is absent. This exposes the existing API of ConcurrentHashMap. 
*/ + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalJavaMap.putIfAbsent(key, TimeStampedValue(currentTime, value)) + Option(prev).map(_.value) } - def toMap: immutable.Map[A, B] = iterator.toMap - /** * Removes old key-value pairs that have timestamp earlier than `threshTime`, * calling the supplied function on each such entry before removing. */ def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalMap.entrySet().iterator() + val iterator = internalJavaMap.entrySet().iterator() while (iterator.hasNext) { val entry = iterator.next() - if (entry.getValue._2 < threshTime) { - f(entry.getKey, entry.getValue._1) + if (entry.getValue.timestamp < threshTime) { + f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) iterator.remove() } @@ -134,5 +92,4 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) } private def currentTime: Long = System.currentTimeMillis() - } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 43848def0ffe6..f2ef96f2fbfa9 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util import scala.collection.{JavaConversions, immutable} @@ -12,8 +29,19 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) } +/** + * A map that stores the timestamp of when a key was inserted along with the value, + * while ensuring that the values are weakly referenced. If the value is garbage collected and + * the weak reference is null, get() operation returns the key be non-existent. However, + * the key is actually not remmoved in the current implementation. Key-value pairs whose + * timestamps are older than a particular threshold time can then be removed using the + * clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a + * drop-in replacement for Scala HashMaps. + * + * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. 
+ */ -private[spark] class TimeStampedWeakValueHashMap[A, B] +private[spark] class TimeStampedWeakValueHashMap[A, B]() extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala index e7c66e494678b..59e35c3abf172 100644 --- a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util import scala.collection.mutable.Map @@ -8,8 +25,8 @@ import scala.reflect.ClassTag /** * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the - * exposed key-value type is different from the internal type. This allows Scala HashMaps to be - * hot replaceable with these Java HashMaps. + * exposed key-value type is different from the internal type. This allows these + * implementations of WrappedJavaHashMap to be drop-in replacements for Scala HashMaps. * * While Java <-> Scala conversion methods exists, its hard to understand the performance * implications and thread safety of the Scala wrapper. This class allows you to convert @@ -62,7 +79,7 @@ private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue) ) } - /* Implicit functions to convert the types. */ + /* Implicit methods to convert the types. */ @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) @@ -76,6 +93,8 @@ private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] internalPairToExternalPair(ip) } + /* Methods that must be implemented for a scala.collection.mutable.Map */ + def get(key: K): Option[V] = { Option(internalJavaMap.get(key)) } @@ -85,6 +104,8 @@ private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) } + /* Other methods that are implemented to ensure performance. 
*/ + def +=(kv: (K, V)): this.type = { internalJavaMap.put(kv._1, kv._2) this @@ -109,8 +130,9 @@ private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] } override def foreach[U](f: ((K, V)) => U) { - while(iterator.hasNext) { - f(iterator.next()) + val jIterator = internalJavaMap.entrySet().iterator() + while(jIterator.hasNext) { + f(jIterator.next()) } } diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index 9fc4681b524e9..7ad65c9681812 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -1,12 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util -import scala.collection.mutable.{HashMap, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import scala.util.Random import java.util +import java.lang.ref.WeakReference import org.scalatest.FunSuite -import scala.util.Random -import java.lang.ref.WeakReference class WrappedJavaHashMapSuite extends FunSuite { @@ -33,7 +50,7 @@ class WrappedJavaHashMapSuite extends FunSuite { assert(map("k1") === "v1") Thread.sleep(10) val threshTime = System.currentTimeMillis() - assert(map.internalMap.get("k1")._2 < threshTime) + assert(map.internalMap.get("k1").timestamp < threshTime) map.clearOldValues(threshTime) assert(map.get("k1") === None) @@ -46,8 +63,8 @@ class WrappedJavaHashMapSuite extends FunSuite { val threshTime1 = System.currentTimeMillis() Thread.sleep(10) assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.internalMap.get("k1")._2 < threshTime1) - assert(map1.internalMap.get("k2")._2 >= threshTime1) + assert(map1.internalMap.get("k1").timestamp < threshTime1) + assert(map1.internalMap.get("k2").timestamp >= threshTime1) map1.clearOldValues(threshTime1) //should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) @@ -126,12 +143,15 @@ class WrappedJavaHashMapSuite extends FunSuite { assert(testMap2.iterator.toSet === pairs.toSet) testMap2("k1") = "v1" + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + // multi remove testMap2 --= keys assert(testMap2.size === 1) assert(testMap2.iterator.toSeq.head === ("k1", "v1")) - - // new instance } } From cb0a5a66ce7dbc2ded209d8bdd0cd88953f70b5f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Mar 2014 11:33:43 -0700 Subject: [PATCH 05/37] Fixed docs and styles. 
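For context, the doc changes below describe the worker-side MapOutputTracker as caching map output statuses in a small BoundedHashMap, which earlier in this series is documented as wrapping an access-ordered java.util.LinkedHashMap behind a synchronized wrapper. The following standalone Scala sketch only illustrates that LRU-bounding idea; it is not Spark's BoundedHashMap, and the class and object names are invented for the example:

    import java.util.{LinkedHashMap => JLinkedHashMap}
    import java.util.Map.{Entry => JMapEntry}

    // Hypothetical illustration, not Spark code: a size-bounded, access-ordered map.
    class LruBoundedCache[K, V](bound: Int) {
      // accessOrder = true keeps iteration in recency order; removeEldestEntry
      // evicts the least recently used entry once the bound is exceeded.
      private val underlying = new JLinkedHashMap[K, V](bound, 0.75f, true) {
        override def removeEldestEntry(eldest: JMapEntry[K, V]): Boolean = size() > bound
      }
      def put(k: K, v: V): Unit = { underlying.put(k, v); () }
      def get(k: K): Option[V] = Option(underlying.get(k))
      def size: Int = underlying.size()
    }

    object LruBoundedCacheDemo extends App {
      val cache = new LruBoundedCache[Int, String](2)
      cache.put(1, "a"); cache.put(2, "b")
      cache.get(1)        // touch key 1, so key 2 becomes the eldest entry
      cache.put(3, "c")   // evicts key 2, not key 1
      assert(cache.get(2).isEmpty && cache.get(1).isDefined && cache.size == 2)
    }

Evicting entries from such a bounded cache is safe on the worker because, as the docs below note, anything dropped is simply re-fetched from the driver on the next lookup.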
--- .../scala/org/apache/spark/ContextCleaner.scala | 13 +++++++------ .../scala/org/apache/spark/MapOutputTracker.scala | 14 +++++++++++++- .../org/apache/spark/util/TimeStampedHashMap.scala | 3 ++- .../spark/util/TimeStampedWeakValueHashMap.scala | 2 +- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 1cc4271f8cf33..3df44ae1fad64 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -30,7 +30,7 @@ private[spark] trait CleanerListener { } /** - * Cleans RDDs and shuffle data. This should be instantiated only on the driver. + * Cleans RDDs and shuffle data. */ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { @@ -62,12 +62,13 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { cleaningThread.interrupt() } - /** Clean all data and metadata related to a RDD, including shuffle files and metadata */ + /** Clean (unpersist) RDD data. */ def cleanRDD(rdd: RDD[_]) { enqueue(CleanRDD(rdd.sparkContext, rdd.id)) logDebug("Enqueued RDD " + rdd + " for cleaning up") } + /** Clean shuffle data. */ def cleanShuffle(shuffleId: Int) { enqueue(CleanShuffle(shuffleId)) logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") @@ -102,16 +103,16 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { /** Perform RDD cleaning */ private def doCleanRDD(sc: SparkContext, rddId: Int) { - logDebug("Cleaning rdd "+ rddId) + logDebug("Cleaning rdd " + rddId) sc.env.blockManager.master.removeRdd(rddId, false) sc.persistentRdds.remove(rddId) listeners.foreach(_.rddCleaned(rddId)) - logInfo("Cleaned rdd "+ rddId) + logInfo("Cleaned rdd " + rddId) } /** Perform shuffle cleaning */ private def doCleanShuffle(shuffleId: Int) { - logDebug("Cleaning shuffle "+ shuffleId) + logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) blockManager.master.removeShuffle(shuffleId) listeners.foreach(_.shuffleCleaned(shuffleId)) @@ -123,4 +124,4 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { private def blockManager = env.blockManager private def isStopped = synchronized { stopped } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ed498696f3e2a..4d0f3dd6cdb71 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -54,6 +54,11 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } +/** + * Class that keeps track of the location of the location of the mapt output of + * a stage. This is abstract because different versions of MapOutputTracker + * (driver and worker) use different HashMap to store its metadata. + */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) @@ -181,6 +186,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } +/** + * MapOutputTracker for the workers. This uses BoundedHashMap to keep track of + * a limited number of most recently used map output information. 
+ */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { /** @@ -192,7 +201,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true) } - +/** + * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map + * output information, which allows old output information based on a TTL. + */ private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 4a4b7d6837bca..60901c5e36130 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -50,7 +50,8 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa override def get(key: A): Option[B] = { val timeStampedValue = internalMap.get(key) if (updateTimeStampOnGet && timeStampedValue != null) { - internalJavaMap.replace(key, timeStampedValue, TimeStampedValue(currentTime, timeStampedValue.value)) + internalJavaMap.replace(key, timeStampedValue, + TimeStampedValue(currentTime, timeStampedValue.value)) } Option(timeStampedValue).map(_.value) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f2ef96f2fbfa9..ea0fde87c56d0 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -33,7 +33,7 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea * A map that stores the timestamp of when a key was inserted along with the value, * while ensuring that the values are weakly referenced. If the value is garbage collected and * the weak reference is null, get() operation returns the key be non-existent. However, - * the key is actually not remmoved in the current implementation. Key-value pairs whose + * the key is actually not removed in the current implementation. Key-value pairs whose * timestamps are older than a particular threshold time can then be removed using the * clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a * drop-in replacement for Scala HashMaps. From ae9da88b3a88a47e9d59b2be327680190c7e904b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Mar 2014 17:56:36 -0700 Subject: [PATCH 06/37] Removed unncessary TimeStampedHashMap from DAGScheduler, added try-catches in finalize() methods, and replaced ArrayBlockingQueue to LinkedBlockingQueue to avoid blocking in Java's finalizing thread. 
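The queue swap matters because RDD.finalize() and ShuffleDependency.finalize() enqueue cleanup tasks from Java's single finalizer thread, which should never be blocked. The following standalone Scala sketch only illustrates the difference between the two queue types; it is not part of the patch, and the object name is invented for the example:

    import java.util.concurrent.{ArrayBlockingQueue, LinkedBlockingQueue}

    object FinalizerQueueSketch extends App {
      // A bounded queue blocks producers once it is full.
      val bounded = new ArrayBlockingQueue[Int](1)
      bounded.put(1)
      // bounded.put(2) would block here until a consumer drains the queue;
      // called from finalize(), that would stall the finalizer thread.
      assert(!bounded.offer(2))   // offer() fails fast instead of blocking

      // A LinkedBlockingQueue created without an explicit bound has capacity
      // Integer.MAX_VALUE, so put() effectively never blocks the producer.
      val unbounded = new LinkedBlockingQueue[Int]()
      unbounded.put(1)
      unbounded.put(2)
      println(s"bounded size = ${bounded.size()}, unbounded size = ${unbounded.size()}")
    }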
--- .../org/apache/spark/ContextCleaner.scala | 5 ++- .../scala/org/apache/spark/Dependency.scala | 19 ++++++++++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 16 +++++++++- .../apache/spark/scheduler/DAGScheduler.scala | 31 ++++--------------- .../spark/streaming/dstream/DStream.scala | 4 ++- 5 files changed, 42 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 3df44ae1fad64..461af1cd11965 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} -import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import org.apache.spark.rdd.RDD @@ -40,8 +40,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { private case class CleanShuffle(id: Int) extends CleaningTask // TODO: add CleanBroadcast - private val QUEUE_CAPACITY = 1000 - private val queue = new ArrayBlockingQueue[CleaningTask](QUEUE_CAPACITY) + private val queue = new LinkedBlockingQueue[CleaningTask] protected val listeners = new ArrayBuffer[CleanerListener] with SynchronizedBuffer[CleanerListener] diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index dba0604ab4866..d24d54576f77a 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -49,13 +49,26 @@ class ShuffleDependency[K, V]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializerClass: String = null) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) with Logging { val shuffleId: Int = rdd.context.newShuffleId() override def finalize() { - if (rdd != null) { - rdd.sparkContext.cleaner.cleanShuffle(shuffleId) + try { + if (rdd != null) { + rdd.sparkContext.cleaner.cleanShuffle(shuffleId) + } + } catch { + case t: Throwable => + // Paranoia - If logError throws error as well, report to stderr. + try { + logError("Error in finalize", t) + } catch { + case _ => + System.err.println("Error in finalize (and could not write to logError): " + t) + } + } finally { + super.finalize() } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index fbb1b486e34f4..e1367131cf569 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1026,6 +1026,7 @@ abstract class RDD[T: ClassTag]( } def cleanup() { + logInfo("Cleanup called on RDD " + id) sc.cleaner.cleanRDD(this) dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]]) .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId) @@ -1112,6 +1113,19 @@ abstract class RDD[T: ClassTag]( } override def finalize() { - cleanup() + try { + cleanup() + } catch { + case t: Throwable => + // Paranoia - If logError throws error as well, report to stderr. 
+ try { + logError("Error in finalize", t) + } catch { + case _ => + System.err.println("Error in finalize (and could not write to logError): " + t) + } + } finally { + super.finalize() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc5b25d845dc2..38628e949a4a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -123,17 +123,17 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) - private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] + private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] - private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] + private[spark] val stageToInfos = new HashMap[Stage, StageInfo] - // An async scheduler event bus. The bus should be stopped when DAGSCheduler is stopped. + // An async scheduler event bus. The bus should be stopped when DAGScheduler is stopped. private[spark] val listenerBus = new SparkListenerBus // Contains the locations that each RDD's partitions are cached on @@ -159,9 +159,6 @@ class DAGScheduler( val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf) - /** * Starts the event processing actor. The actor has two responsibilities: * @@ -1094,26 +1091,10 @@ class DAGScheduler( Nil } - private def cleanup(cleanupTime: Long) { - Map( - "stageIdToStage" -> stageIdToStage, - "shuffleToMapStage" -> shuffleToMapStage, - "pendingTasks" -> pendingTasks, - "stageToInfos" -> stageToInfos, - "jobIdToStageIds" -> jobIdToStageIds, - "stageIdToJobIds" -> stageIdToJobIds). - foreach { case(s, t) => { - val sizeBefore = t.size - t.clearOldValues(cleanupTime) - logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) - }} - } - def stop() { if (eventProcessActor != null) { eventProcessActor ! 
StopDAGScheduler } - metadataCleaner.cancel() taskSched.stop() listenerBus.stop() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 6bff56a9d332a..3208359306bee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -341,9 +341,11 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + logDebug("Clearing references to old RDDs: [" + + oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) { - logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", ")) + logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) oldRDDs.values.foreach(_.unpersist(false)) } logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + From e61daa02e9e221625489ea1dd434cf6d3192e474 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 12 Mar 2014 19:08:42 -0700 Subject: [PATCH 07/37] Modifications based on the comments on PR 126. --- .../org/apache/spark/ContextCleaner.scala | 55 +++++++++++-------- .../org/apache/spark/MapOutputTracker.scala | 52 ++++++++---------- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../spark/storage/ShuffleBlockManager.scala | 10 +++- .../apache/spark/util/MetadataCleaner.scala | 4 +- .../apache/spark/ContextCleanerSuite.scala | 2 +- .../spark/util/WrappedJavaHashMapSuite.scala | 2 +- 9 files changed, 71 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 461af1cd11965..8f76b91753157 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -21,8 +21,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} -import org.apache.spark.rdd.RDD - /** Listener class used for testing when any item has been cleaned by the Cleaner class */ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) @@ -32,12 +30,12 @@ private[spark] trait CleanerListener { /** * Cleans RDDs and shuffle data. 
*/ -private[spark] class ContextCleaner(env: SparkEnv) extends Logging { +private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Classes to represent cleaning tasks */ private sealed trait CleaningTask - private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask - private case class CleanShuffle(id: Int) extends CleaningTask + private case class CleanRDD(rddId: Int) extends CleaningTask + private case class CleanShuffle(shuffleId: Int) extends CleaningTask // TODO: add CleanBroadcast private val queue = new LinkedBlockingQueue[CleaningTask] @@ -47,7 +45,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - private var stopped = false + @volatile private var stopped = false /** Start the cleaner */ def start() { @@ -57,26 +55,37 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { /** Stop the cleaner */ def stop() { - synchronized { stopped = true } + stopped = true cleaningThread.interrupt() } - /** Clean (unpersist) RDD data. */ - def cleanRDD(rdd: RDD[_]) { - enqueue(CleanRDD(rdd.sparkContext, rdd.id)) - logDebug("Enqueued RDD " + rdd + " for cleaning up") + /** + * Clean (unpersist) RDD data. Do not perform any time or resource intensive + * computation in this function as this is called from a finalize() function. + */ + def cleanRDD(rddId: Int) { + enqueue(CleanRDD(rddId)) + logDebug("Enqueued RDD " + rddId + " for cleaning up") } - /** Clean shuffle data. */ + /** + * Clean shuffle data. Do not perform any time or resource intensive + * computation in this function as this is called from a finalize() function. + */ def cleanShuffle(shuffleId: Int) { enqueue(CleanShuffle(shuffleId)) logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") } + /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener) { listeners += listener } - /** Enqueue a cleaning task */ + + /** + * Enqueue a cleaning task. Do not perform any time or resource intensive + * computation in this function as this is called from a finalize() function. 
+ */ private def enqueue(task: CleaningTask) { queue.put(task) } @@ -86,16 +95,16 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { try { while (!isStopped) { val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS)) - if (taskOpt.isDefined) { + taskOpt.foreach(task => { logDebug("Got cleaning task " + taskOpt.get) - taskOpt.get match { - case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId) + task match { + case CleanRDD(rddId) => doCleanRDD(sc, rddId) case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId) } - } + }) } } catch { - case ie: java.lang.InterruptedException => + case ie: InterruptedException => if (!isStopped) logWarning("Cleaning thread interrupted") } } @@ -103,7 +112,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { /** Perform RDD cleaning */ private def doCleanRDD(sc: SparkContext, rddId: Int) { logDebug("Cleaning rdd " + rddId) - sc.env.blockManager.master.removeRdd(rddId, false) + blockManagerMaster.removeRdd(rddId, false) sc.persistentRdds.remove(rddId) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned rdd " + rddId) @@ -113,14 +122,14 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging { private def doCleanShuffle(shuffleId: Int) { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) - blockManager.master.removeShuffle(shuffleId) + blockManagerMaster.removeShuffle(shuffleId) listeners.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } - private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - private def blockManager = env.blockManager + private def blockManagerMaster = sc.env.blockManager.master - private def isStopped = synchronized { stopped } + private def isStopped = stopped } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4d0f3dd6cdb71..27f94ce0e42d0 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -17,22 +17,18 @@ package org.apache.spark -import scala.Some -import scala.collection.mutable.{HashSet, Map} -import scala.concurrent.Await - import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashSet +import scala.Some +import scala.collection.mutable.{HashSet, Map} import scala.concurrent.Await import akka.actor._ import akka.pattern.ask - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AkkaUtils, TimeStampedHashMap, BoundedHashMap} +import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) @@ -55,7 +51,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } /** - * Class that keeps track of the location of the location of the mapt output of + * Class that keeps track of the location of the location of the map output of * a stage. This is abstract because different versions of MapOutputTracker * (driver and worker) use different HashMap to store its metadata. 
*/ @@ -155,10 +151,6 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - protected def cleanup(cleanupTime: Long) { - mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime) - } - def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() @@ -195,10 +187,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** * Bounded HashMap for storing serialized statuses in the worker. This allows * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be - * automatically repopulated by fetching them again from the driver. + * automatically repopulated by fetching them again from the driver. Its okay to + * keep the cache size small as it unlikely that there will be a very large number of + * stages active simultaneously in the worker. */ - protected val MAX_MAP_STATUSES = 100 - protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true) + protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]]( + conf.getInt("spark.mapOutputTracker.cacheSize", 100), true + ) } /** @@ -212,20 +207,18 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) private var cacheEpoch = epoch /** - * Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped - * only by explicit deregistering or by ttl-based cleaning (if set). Other than these two + * Timestamp based HashMap for storing mapStatuses and cached serialized statuses + * in the master, so that statuses are dropped only by explicit deregistering or + * by TTL-based cleaning (if set). Other than these two * scenarios, nothing should be dropped from this HashMap. */ + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() - /** - * Bounded HashMap for storing serialized statuses in the master. This allows - * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be - * automatically repopulated by serializing the lost statuses again . - */ - protected val MAX_SERIALIZED_STATUSES = 100 - private val cachedSerializedStatuses = - new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true) + // For cleaning up TimeStampedHashMaps + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { @@ -264,6 +257,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) + cachedSerializedStatuses.remove(shuffleId) } def incrementEpoch() { @@ -303,11 +297,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } def contains(shuffleId: Int): Boolean = { - mapStatuses.contains(shuffleId) + cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } override def stop() { super.stop() + metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -315,8 +310,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) // This might be called on the MapOutputTrackerMaster if we're running in local mode. 
} - def has(shuffleId: Int): Boolean = { - cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + protected def cleanup(cleanupTime: Long) { + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 74d10196980cf..b80c58489cb52 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -206,7 +206,7 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() - private[spark] val cleaner = new ContextCleaner(env) + private[spark] val cleaner = new ContextCleaner(this) cleaner.start() ui.start() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e1367131cf569..f2e20a108630a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1027,7 +1027,7 @@ abstract class RDD[T: ClassTag]( def cleanup() { logInfo("Cleanup called on RDD " + id) - sc.cleaner.cleanRDD(this) + sc.cleaner.cleanRDD(id) dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]]) .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId) .foreach(sc.cleaner.cleanShuffle) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 38628e949a4a6..1a5cd82571a08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -266,7 +266,7 @@ class DAGScheduler( : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) - if (mapOutputTracker.has(shuffleDep.shuffleId)) { + if (mapOutputTracker.contains(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index ed03f189fb4ac..cf83a60ffb9e8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,8 +169,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } - /** Remove all the blocks / files related to a particular shuffle */ + /** Remove all the blocks / files and metadata related to a particular shuffle */ def removeShuffle(shuffleId: ShuffleId) { + removeShuffleBlocks(shuffleId) + shuffleStates.remove(shuffleId) + } + + /** Remove all the blocks / files related to a particular shuffle */ + private def removeShuffleBlocks(shuffleId: ShuffleId) { shuffleStates.get(shuffleId) match { case Some(state) => if (consolidateShuffleFiles) { @@ -194,7 +200,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId)) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } } diff --git 
a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2553db4ad589e..2ef853710a554 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -62,8 +62,8 @@ private[spark] class MetadataCleaner( private[spark] object MetadataCleanerType extends Enumeration { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER, - SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, + SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 2ec314aa632f3..cb827b9e955a9 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -25,7 +25,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newRDD.persist() rdd.count() val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - cleaner.cleanRDD(rdd) + cleaner.cleanRDD(rdd.id) tester.assertCleanup } diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index 7ad65c9681812..f0a84064ab9fb 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -206,4 +206,4 @@ class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { new TestMap[K1, V1] } -} \ No newline at end of file +} From a7260d346882bcdfe6e5014c52960017fb602300 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 17 Mar 2014 15:49:50 -0700 Subject: [PATCH 08/37] Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap. --- .../org/apache/spark/ContextCleaner.scala | 50 +++++++++++-------- .../org/apache/spark/MapOutputTracker.scala | 1 - .../util/TimeStampedWeakValueHashMap.scala | 47 ++++++++++++----- 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 8f76b91753157..7636c6cf64972 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -50,6 +50,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Start the cleaner */ def start() { cleaningThread.setDaemon(true) + cleaningThread.setName("ContextCleaner") cleaningThread.start() } @@ -60,7 +61,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** - * Clean (unpersist) RDD data. Do not perform any time or resource intensive + * Clean RDD data. Do not perform any time or resource intensive * computation in this function as this is called from a finalize() function. 
*/ def cleanRDD(rddId: Int) { @@ -92,39 +93,48 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Keep cleaning RDDs and shuffle data */ private def keepCleaning() { - try { - while (!isStopped) { + while (!isStopped) { + try { val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS)) - taskOpt.foreach(task => { + taskOpt.foreach { task => logDebug("Got cleaning task " + taskOpt.get) task match { - case CleanRDD(rddId) => doCleanRDD(sc, rddId) + case CleanRDD(rddId) => doCleanRDD(rddId) case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId) } - }) + } + } catch { + case ie: InterruptedException => + if (!isStopped) logWarning("Cleaning thread interrupted") + case t: Throwable => logError("Error in cleaning thread", t) } - } catch { - case ie: InterruptedException => - if (!isStopped) logWarning("Cleaning thread interrupted") } } /** Perform RDD cleaning */ - private def doCleanRDD(sc: SparkContext, rddId: Int) { - logDebug("Cleaning rdd " + rddId) - blockManagerMaster.removeRdd(rddId, false) - sc.persistentRdds.remove(rddId) - listeners.foreach(_.rddCleaned(rddId)) - logInfo("Cleaned rdd " + rddId) + private def doCleanRDD(rddId: Int) { + try { + logDebug("Cleaning RDD " + rddId) + blockManagerMaster.removeRdd(rddId, false) + sc.persistentRdds.remove(rddId) + listeners.foreach(_.rddCleaned(rddId)) + logInfo("Cleaned RDD " + rddId) + } catch { + case t: Throwable => logError("Error cleaning RDD " + rddId, t) + } } /** Perform shuffle cleaning */ private def doCleanShuffle(shuffleId: Int) { - logDebug("Cleaning shuffle " + shuffleId) - mapOutputTrackerMaster.unregisterShuffle(shuffleId) - blockManagerMaster.removeShuffle(shuffleId) - listeners.foreach(_.shuffleCleaned(shuffleId)) - logInfo("Cleaned shuffle " + shuffleId) + try { + logDebug("Cleaning shuffle " + shuffleId) + mapOutputTrackerMaster.unregisterShuffle(shuffleId) + blockManagerMaster.removeShuffle(shuffleId) + listeners.foreach(_.shuffleCleaned(shuffleId)) + logInfo("Cleaned shuffle " + shuffleId) + } catch { + case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t) + } } private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 27f94ce0e42d0..f37a9d41b2237 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.Some import scala.collection.mutable.{HashSet, Map} import scala.concurrent.Await diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index ea0fde87c56d0..bd86d78b8010f 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -24,6 +24,7 @@ import java.lang.ref.WeakReference import java.util.concurrent.ConcurrentHashMap import org.apache.spark.Logging +import java.util.concurrent.atomic.AtomicInteger private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) @@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: 
Long, weakValue: Wea private[spark] class TimeStampedWeakValueHashMap[A, B]() extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + /** Number of inserts after which keys whose weak ref values are null will be cleaned */ + private val CLEANUP_INTERVAL = 1000 + + /** Counter for counting the number of inserts */ + private val insertCounts = new AtomicInteger(0) + protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() } @@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]() new TimeStampedWeakValueHashMap[K1, V1]() } + override def +=(kv: (A, B)): this.type = { + // Cleanup null value at certain intervals + if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) { + cleanNullValues() + } + super.+=(kv) + } + override def get(key: A): Option[B] = { Option(internalJavaMap.get(key)) match { case Some(weakValue) => val value = weakValue.weakValue.get - if (value == null) cleanupKey(key) + if (value == null) { + internalJavaMap.remove(key) + } Option(value) case None => None @@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]() } override def iterator: Iterator[(A, B)] = { - val jIterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).flatMap(kv => { - val key = kv.getKey - val value = kv.getValue.weakValue.get - if (value == null) { - cleanupKey(key) - Seq.empty - } else { - Seq((key, value)) - } + val iterator = internalJavaMap.entrySet().iterator() + JavaConversions.asScalaIterator(iterator).flatMap(kv => { + val (key, value) = (kv.getKey, kv.getValue.weakValue.get) + if (value != null) Seq((key, value)) else Seq.empty }) } @@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]() } } - private def cleanupKey(key: A) { - // TODO: Consider cleaning up keys to empty weak ref values automatically in future. + /** + * Removes keys whose weak referenced values have become null. + */ + private def cleanNullValues() { + val iterator = internalJavaMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue.weakValue.get == null) { + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } } private def currentTime = System.currentTimeMillis() From 892b9520d828cfa7049e6ec70345b3502b139a8e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Mar 2014 15:09:24 -0700 Subject: [PATCH 09/37] Removed use of BoundedHashMap, and made BlockManagerSlaveActor cleanup shuffle metadata in MapOutputTrackerWorker. 
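As a rough illustration of the approach this commit moves to, the following standalone sketch models a worker-side cache backed by a plain mutable HashMap: statuses are fetched from the driver on demand and dropped explicitly when the worker is told to remove a shuffle (in the patch this happens in BlockManagerSlaveActor's RemoveShuffle handler, which also calls mapOutputTracker.unregisterShuffle). The names below (WorkerMapOutputCache, fetchFromDriver, the demo object) are illustrative only, not Spark classes.

import scala.collection.mutable

class WorkerMapOutputCache(fetchFromDriver: Int => Array[String]) {
  private val statuses = new mutable.HashMap[Int, Array[String]]

  /** Return cached statuses for a shuffle, fetching them from the driver on a miss. */
  def getStatuses(shuffleId: Int): Array[String] =
    statuses.getOrElseUpdate(shuffleId, fetchFromDriver(shuffleId))

  /** Drop the local entry when the worker is asked to clean up a shuffle. */
  def unregisterShuffle(shuffleId: Int): Unit = {
    statuses.remove(shuffleId)
  }
}

object WorkerMapOutputCacheDemo extends App {
  val cache = new WorkerMapOutputCache(id => Array(s"status-for-shuffle-$id"))
  println(cache.getStatuses(0).mkString(", "))  // fetched from the "driver"
  cache.unregisterShuffle(0)                    // a RemoveShuffle request arrives
  println(cache.getStatuses(0).mkString(", "))  // transparently refetched
}

Because stale entries are removed explicitly and refetched on the next access, the bounded LRU cache is no longer needed on the worker side, which is why the BoundedHashMap can be replaced with an ordinary HashMap.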
--- .../org/apache/spark/ContextCleaner.scala | 19 +++- .../scala/org/apache/spark/Dependency.scala | 4 +- .../org/apache/spark/MapOutputTracker.scala | 106 +++++++++--------- .../scala/org/apache/spark/SparkEnv.scala | 25 +++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 15 +-- .../apache/spark/scheduler/DAGScheduler.scala | 8 +- .../apache/spark/scheduler/ResultTask.scala | 10 +- .../spark/scheduler/ShuffleMapTask.scala | 12 +- .../apache/spark/storage/BlockManager.scala | 14 ++- .../storage/BlockManagerSlaveActor.scala | 6 +- .../spark/storage/DiskBlockManager.scala | 2 +- .../apache/spark/storage/ThreadingTest.scala | 5 +- .../apache/spark/ContextCleanerSuite.scala | 14 ++- .../apache/spark/MapOutputTrackerSuite.scala | 6 +- .../spark/storage/BlockManagerSuite.scala | 91 ++++++++++----- .../spark/storage/DiskBlockManagerSuite.scala | 4 +- .../spark/util/WrappedJavaHashMapSuite.scala | 2 + 17 files changed, 196 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 7636c6cf64972..5d996ed34dff5 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import org.apache.spark.storage.StorageLevel /** Listener class used for testing when any item has been cleaned by the Cleaner class */ private[spark] trait CleanerListener { @@ -61,19 +62,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** - * Clean RDD data. Do not perform any time or resource intensive + * Schedule cleanup of RDD data. Do not perform any time or resource intensive * computation in this function as this is called from a finalize() function. */ - def cleanRDD(rddId: Int) { + def scheduleRDDCleanup(rddId: Int) { enqueue(CleanRDD(rddId)) logDebug("Enqueued RDD " + rddId + " for cleaning up") } /** - * Clean shuffle data. Do not perform any time or resource intensive + * Schedule cleanup of shuffle data. Do not perform any time or resource intensive * computation in this function as this is called from a finalize() function. */ - def cleanShuffle(shuffleId: Int) { + def scheduleShuffleCleanup(shuffleId: Int) { enqueue(CleanShuffle(shuffleId)) logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") } @@ -83,6 +84,13 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { listeners += listener } + /** Unpersists RDD and remove all blocks for it from memory and disk. */ + def unpersistRDD(rddId: Int, blocking: Boolean) { + logDebug("Unpersisted RDD " + rddId) + sc.env.blockManager.master.removeRdd(rddId, blocking) + sc.persistentRdds.remove(rddId) + } + /** * Enqueue a cleaning task. Do not perform any time or resource intensive * computation in this function as this is called from a finalize() function. 
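A minimal, self-contained sketch of the pattern the cleaner uses at this point in the series: finalize() only enqueues a small task object, and a daemon thread drains the queue with a timed poll and performs the actual cleanup work. The names below (CleanupQueueSketch and its members) are illustrative, not Spark's own.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object CleanupQueueSketch {
  sealed trait CleanupTask
  case class CleanRDD(rddId: Int) extends CleanupTask

  private val queue = new LinkedBlockingQueue[CleanupTask]
  @volatile private var stopped = false

  private val cleaningThread = new Thread("cleanup-sketch") {
    override def run(): Unit = {
      while (!stopped) {
        // Poll with a timeout so the loop can notice `stopped` and exit promptly.
        Option(queue.poll(100, TimeUnit.MILLISECONDS)).foreach {
          case CleanRDD(id) => println(s"cleaning RDD $id")  // real code would remove blocks
        }
      }
    }
  }
  cleaningThread.setDaemon(true)

  /** Cheap enough to be called from a finalize() method: just an enqueue. */
  def scheduleRDDCleanup(rddId: Int): Unit = queue.put(CleanRDD(rddId))

  def main(args: Array[String]): Unit = {
    cleaningThread.start()
    scheduleRDDCleanup(42)
    Thread.sleep(500)
    stopped = true
  }
}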
@@ -115,8 +123,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def doCleanRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) - blockManagerMaster.removeRdd(rddId, false) - sc.persistentRdds.remove(rddId) + unpersistRDD(rddId, false) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index d24d54576f77a..557d424d7a786 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -56,7 +56,7 @@ class ShuffleDependency[K, V]( override def finalize() { try { if (rdd != null) { - rdd.sparkContext.cleaner.cleanShuffle(shuffleId) + rdd.sparkContext.cleaner.scheduleShuffleCleanup(shuffleId) } } catch { case t: Throwable => @@ -64,7 +64,7 @@ class ShuffleDependency[K, V]( try { logError("Error in finalize", t) } catch { - case _ => + case _ : Throwable => System.err.println("Error in finalize (and could not write to logError): " + t) } } finally { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index f37a9d41b2237..ffdf9115e1aae 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await import akka.actor._ @@ -34,6 +34,7 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +/** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) extends Actor with Logging { def receive = { @@ -50,7 +51,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } /** - * Class that keeps track of the location of the location of the map output of + * Class that keeps track of the location of the map output of * a stage. This is abstract because different versions of MapOutputTracker * (driver and worker) use different HashMap to store its metadata. */ @@ -58,20 +59,27 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val timeout = AkkaUtils.askTimeout(conf) - // Set to the MapOutputTrackerActor living on the driver + /** Set to the MapOutputTrackerActor living on the driver */ var trackerActor: ActorRef = _ /** This HashMap needs to have different storage behavior for driver and worker */ protected val mapStatuses: Map[Int, Array[MapStatus]] - // Incremented every time a fetch fails so that client nodes know to clear - // their cache of map output locations if this happens. + /** + * Incremented every time a fetch fails so that client nodes know to clear + * their cache of map output locations if this happens. + */ protected var epoch: Long = 0 protected val epochLock = new java.lang.Object - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. 
- private def askTracker(message: Any): Any = { + /** Remembers which map output locations are currently being fetched on a worker */ + private val fetching = new HashSet[Int] + + /** + * Send a message to the trackerActor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + protected def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) Await.result(future, timeout) @@ -81,17 +89,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - private def communicate(message: Any) { + /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + protected def sendTracker(message: Any) { if (askTracker(message) != true) { throw new SparkException("Error reply received from MapOutputTracker") } } - // Remembers which map output locations are currently being fetched on a worker - private val fetching = new HashSet[Int] - - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + /** + * Called from executors to get the server URIs and + * output sizes of the map outputs of a given shuffle + */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { @@ -150,22 +158,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - def stop() { - communicate(StopMapOutputTracker) - mapStatuses.clear() - trackerActor = null - } - - // Called to get current epoch number + /** Called to get current epoch number */ def getEpoch: Long = { epochLock.synchronized { return epoch } } - // Called on workers to update the epoch number, potentially clearing old outputs - // because of a fetch failure. (Each worker task calls this with the latest epoch - // number on the master at the time it was created.) + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each worker task calls this with the latest epoch + * number on the master at the time it was created. + */ def updateEpoch(newEpoch: Long) { epochLock.synchronized { if (newEpoch > epoch) { @@ -175,24 +179,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } -} -/** - * MapOutputTracker for the workers. This uses BoundedHashMap to keep track of - * a limited number of most recently used map output information. - */ -private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + /** Unregister shuffle data */ + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } - /** - * Bounded HashMap for storing serialized statuses in the worker. This allows - * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be - * automatically repopulated by fetching them again from the driver. Its okay to - * keep the cache size small as it unlikely that there will be a very large number of - * stages active simultaneously in the worker. 
- */ - protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]]( - conf.getInt("spark.mapOutputTracker.cacheSize", 100), true - ) + def stop() { + sendTracker(StopMapOutputTracker) + mapStatuses.clear() + trackerActor = null + } } /** @@ -202,7 +199,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { - // Cache a serialized version of the output statuses for each shuffle to send them out faster + /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch /** @@ -211,7 +208,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) * by TTL-based cleaning (if set). Other than these two * scenarios, nothing should be dropped from this HashMap. */ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() @@ -232,6 +228,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Register multiple map output information for the given shuffle */ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) if (changeEpoch) { @@ -239,6 +236,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { val arrayOpt = mapStatuses.get(shuffleId) if (arrayOpt.isDefined && arrayOpt.get != null) { @@ -254,11 +252,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } - def unregisterShuffle(shuffleId: Int) { + /** Unregister shuffle data */ + override def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) cachedSerializedStatuses.remove(shuffleId) } + /** Check if the given shuffle is being tracked */ + def containsShuffle(shuffleId: Int): Boolean = { + cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -295,26 +299,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) bytes } - def contains(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) - } - override def stop() { super.stop() metadataCleaner.cancel() cachedSerializedStatuses.clear() } - override def updateEpoch(newEpoch: Long) { - // This might be called on the MapOutputTrackerMaster if we're running in local mode. - } - protected def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } } +/** + * MapOutputTracker for the workers, which fetches map output information from the driver's + * MapOutputTrackerMaster. 
+ */ +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + protected val mapStatuses = new HashMap[Int, Array[MapStatus]] +} + private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index fdfd00660377f..f636f6363b34b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -165,18 +165,6 @@ object SparkEnv extends Logging { } } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf)), conf) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager) - - val connectionManager = blockManager.connectionManager - - val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) - - val cacheManager = new CacheManager(blockManager) - // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself val mapOutputTracker = if (isDriver) { @@ -188,6 +176,19 @@ object SparkEnv extends Logging { "MapOutputTracker", new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( + "BlockManagerMaster", + new BlockManagerMasterActor(isLocal, conf)), conf) + + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + serializer, conf, securityManager, mapOutputTracker) + + val connectionManager = blockManager.connectionManager + + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + + val cacheManager = new CacheManager(blockManager) + val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f2e20a108630a..a75bca42257d4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -165,8 +165,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.env.blockManager.master.removeRdd(id, blocking) - sc.persistentRdds.remove(id) + sc.cleaner.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } @@ -1025,14 +1024,6 @@ abstract class RDD[T: ClassTag]( checkpointData.flatMap(_.getCheckpointFile) } - def cleanup() { - logInfo("Cleanup called on RDD " + id) - sc.cleaner.cleanRDD(id) - dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId) - .foreach(sc.cleaner.cleanShuffle) - } - // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1114,14 +1105,14 @@ abstract class RDD[T: ClassTag]( override def finalize() { try { - cleanup() + sc.cleaner.scheduleRDDCleanup(id) } catch { case t: Throwable => // Paranoia - If logError throws error as well, report to stderr. 
try { logError("Error in finalize", t) } catch { - case _ => + case _ : Throwable => System.err.println("Error in finalize (and could not write to logError): " + t) } } finally { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1a5cd82571a08..253b19880c700 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -32,7 +32,6 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -154,7 +153,7 @@ class DAGScheduler( val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures // Missing tasks from each stage - val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] + val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] @@ -266,7 +265,7 @@ class DAGScheduler( : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) - if (mapOutputTracker.contains(shuffleDep.shuffleId)) { + if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { @@ -398,6 +397,9 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId + ShuffleMapTask.removeStage(stageId) + ResultTask.removeStage(stageId) + logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 59fd630e0431a..083fb895d8696 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -20,17 +20,17 @@ package org.apache.spark.scheduler import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.mutable.HashMap + import org.apache.spark._ import org.apache.spark.rdd.{RDD, RDDCheckpointData} -import org.apache.spark.util.BoundedHashMap private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. 
- val MAX_CACHE_SIZE = 100 - val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { @@ -63,6 +63,10 @@ private[spark] object ResultTask { (rdd, func) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index df3a7b9ee37ad..bb2eda79ea249 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,24 +17,22 @@ package org.apache.spark.scheduler -import scala.collection.mutable.HashMap - import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.mutable.HashMap + import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedHashMap private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val MAX_CACHE_SIZE = 100 - val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -75,6 +73,10 @@ private[spark] object ShuffleMapTask { HashMap(set.toSeq: _*) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f2aff78914f96..091df41412f6c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -48,8 +48,9 @@ private[spark] class BlockManager( val defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager) - extends Logging { + securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker + ) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, @@ -89,7 +90,7 @@ private[spark] class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) // Pending 
reregistration action being executed asynchronously or null if none @@ -123,9 +124,10 @@ private[spark] class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = { + serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf, - securityManager) + securityManager, mapOutputTracker) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 9ff7aacec141a..dfc19591781d0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import akka.actor.Actor +import org.apache.spark.MapOutputTracker import org.apache.spark.storage.BlockManagerMessages._ /** @@ -26,7 +27,7 @@ import org.apache.spark.storage.BlockManagerMessages._ * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { +class BlockManagerSlaveActor(blockManager: BlockManager, mapOutputTracker: MapOutputTracker) extends Actor { override def receive = { case RemoveBlock(blockId) => @@ -38,5 +39,8 @@ class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { case RemoveShuffle(shuffleId) => blockManager.shuffleBlockManager.removeShuffle(shuffleId) + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index cdee285a1cbd4..a57e6f710305a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -91,7 +91,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) /** Check if disk block manager has a block */ - def contains(blockId: BlockId): Boolean = { + def containsBlock(blockId: BlockId): Boolean = { getBlockLocation(blockId).file.exists() } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 36f2a0fd02724..233754f6eddfd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,9 +22,8 @@ import java.util.concurrent.ArrayBlockingQueue import akka.actor._ import util.Random -import org.apache.spark.SparkConf +import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SecurityManager} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.{SecurityManager, SparkConf} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -100,7 +99,7 @@ private[spark] object ThreadingTest { actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf)) + new SecurityManager(conf), new 
MapOutputTrackerMaster(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index cb827b9e955a9..8556888c96e06 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -25,7 +25,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newRDD.persist() rdd.count() val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - cleaner.cleanRDD(rdd.id) + cleaner.scheduleRDDCleanup(rdd.id) tester.assertCleanup } @@ -33,7 +33,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newShuffleRDD rdd.count() val tester = new CleanerTester(sc, shuffleIds = Seq(0)) - cleaner.cleanShuffle(0) + cleaner.scheduleShuffleCleanup(0) tester.assertCleanup } @@ -106,6 +106,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo postGCTester.assertCleanup } + // TODO (TD): Test that cleaned up RDD and shuffle can be recomputed again correctly. + def newRDD = sc.makeRDD(1 to 10) def newPairRDD = newRDD.map(_ -> 1) @@ -173,9 +175,9 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In "One or more RDDs' blocks cannot be found in block manager, cannot start cleaner test") // Verify the shuffle ids are registered and blocks are present - assert(shuffleIds.forall(mapOutputTrackerMaster.contains), + assert(shuffleIds.forall(mapOutputTrackerMaster.containsShuffle), "One or more shuffles have not been registered cannot start cleaner test") - assert(shuffleIds.forall(shuffleId => diskBlockManager.contains(shuffleBlockId(shuffleId))), + assert(shuffleIds.forall(shuffleId => diskBlockManager.containsBlock(shuffleBlockId(shuffleId))), "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") } @@ -185,8 +187,8 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) // Verify all the shuffle have been deregistered and cleaned up - assert(shuffleIds.forall(!mapOutputTrackerMaster.contains(_))) - assert(shuffleIds.forall(shuffleId => !diskBlockManager.contains(shuffleBlockId(shuffleId)))) + assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) + assert(shuffleIds.forall(shuffleId => !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) } private def uncleanedResourcesToString = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 9091ab9265465..9358099abbe24 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -60,7 +60,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) tracker.registerShuffle(10, 2) - assert(tracker.contains(10)) + assert(tracker.containsShuffle(10)) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = 
MapOutputTracker.decompressSize(compressedSize1000) @@ -86,10 +86,10 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { Array(compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), Array(compressedSize10000, compressedSize1000))) - assert(tracker.contains(10)) + assert(tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).nonEmpty) tracker.unregisterShuffle(10) - assert(!tracker.contains(10)) + assert(!tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).isEmpty) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1036b9f34e9dd..197b1004990ce 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -41,6 +41,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -128,7 +129,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -158,9 +160,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, - securityMgr) + securityMgr, mapOutputTracker) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -175,7 +178,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -223,7 +227,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, 
serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -257,7 +262,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -273,7 +279,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -292,7 +299,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -329,7 +337,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -348,7 +357,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -367,7 +377,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -386,7 +397,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -409,7 +421,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 
1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -422,7 +435,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -437,7 +451,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -452,7 +467,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -467,7 +483,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -482,7 +499,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -507,7 +525,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -531,7 +550,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -577,7 +597,8 @@ class BlockManagerSuite extends FunSuite with 
BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, + securityMgr, mapOutputTracker) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -588,7 +609,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -596,7 +618,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -604,7 +627,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -612,28 +636,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, 
"rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -648,7 +676,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr) + securityMgr, mapOutputTracker) // The put should fail since a1 is not serializable. class UnserializableClass @@ -664,7 +692,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index d594d2bc06760..0dd34223787cd 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -59,9 +59,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assertSegmentEquals(blockId, blockId.name, 0, 10) - assert(diskBlockManager.contains(blockId)) + assert(diskBlockManager.containsBlock(blockId)) newFile.delete() - assert(!diskBlockManager.contains(blockId)) + assert(!diskBlockManager.containsBlock(blockId)) } test("block appending") { diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index f0a84064ab9fb..37c1f748a6f3d 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -107,6 +107,8 @@ class WrappedJavaHashMapSuite extends FunSuite { } assert(map.internalJavaMap.get("k1").weakValue.get == null) assert(map.get("k1") === None) + + // TODO (TD): Test clearing of null-value pairs } def testMap(hashMapConstructor: => Map[String, String]) { From e1fba5fee616d810afc2f33979af34721c35238e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Mar 2014 17:06:51 -0700 Subject: [PATCH 10/37] Style fix --- .../org/apache/spark/storage/BlockManagerSlaveActor.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index dfc19591781d0..a6ff147c1d3e6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -27,7 +27,10 @@ import org.apache.spark.storage.BlockManagerMessages._ * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor(blockManager: BlockManager, mapOutputTracker: MapOutputTracker) extends Actor { +class BlockManagerSlaveActor( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker + ) extends Actor { override def receive = { case RemoveBlock(blockId) => From f2881fd7d4afaead50632418c4a927ecd09eac65 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Mar 2014 11:41:22 -0700 Subject: [PATCH 11/37] Changed ContextCleaner to use ReferenceQueue instead of finalizer --- .../org/apache/spark/ContextCleaner.scala | 80 ++++++++------ .../scala/org/apache/spark/Dependency.scala | 19 +--- .../main/scala/org/apache/spark/rdd/RDD.scala | 18 +--- .../apache/spark/ContextCleanerSuite.scala | 100 +++++++++++++----- 4 files changed, 123 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5d996ed34dff5..d499af20502d0 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -17,10 +17,11 @@ package org.apache.spark +import java.lang.ref.{ReferenceQueue, WeakReference} + import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.rdd.RDD /** Listener class used for testing when any item has been cleaned by the Cleaner class */ private[spark] trait CleanerListener { @@ -34,20 +35,27 @@ private[spark] trait CleanerListener { private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Classes to represent cleaning tasks */ - private sealed trait CleaningTask - private case class CleanRDD(rddId: Int) extends CleaningTask - private case class CleanShuffle(shuffleId: Int) extends CleaningTask + private sealed trait CleanupTask + private case class CleanRDD(rddId: Int) extends CleanupTask + private case class CleanShuffle(shuffleId: Int) extends CleanupTask // TODO: add CleanBroadcast - private val queue = new LinkedBlockingQueue[CleaningTask] + private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask] + with SynchronizedBuffer[WeakReferenceWithCleanupTask] + private val referenceQueue = new ReferenceQueue[AnyRef] - protected val listeners = new ArrayBuffer[CleanerListener] + private val listeners = new ArrayBuffer[CleanerListener] with SynchronizedBuffer[CleanerListener] private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + private val REF_QUEUE_POLL_TIMEOUT = 100 + @volatile private var stopped = false + private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask) + extends WeakReference(referent, referenceQueue) + /** Start the cleaner */ def start() { cleaningThread.setDaemon(true) @@ -62,21 +70,27 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** - * Schedule cleanup of RDD data. Do not perform any time or resource intensive - * computation in this function as this is called from a finalize() function. + * Register a RDD for cleanup when it is garbage collected. 
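A minimal, self-contained sketch of the weak-reference registration pattern this patch introduces (illustrative only; Task, DropRdd, TaskRef and RegisterSketch are hypothetical names, not part of the diff):

    import java.lang.ref.{ReferenceQueue, WeakReference}
    import scala.collection.mutable.ArrayBuffer

    // Illustrative stand-ins for the patch's CleanupTask hierarchy.
    sealed trait Task
    case class DropRdd(rddId: Int) extends Task

    // A weak reference that carries the work to do once its referent becomes
    // only weakly reachable; the JVM then enqueues it on `queue`.
    class TaskRef(referent: AnyRef, val task: Task, queue: ReferenceQueue[AnyRef])
      extends WeakReference[AnyRef](referent, queue)

    object RegisterSketch {
      val queue = new ReferenceQueue[AnyRef]
      // The TaskRefs themselves must stay strongly reachable (their referents
      // need not), or they could be collected before they are ever enqueued.
      private val refs = ArrayBuffer.empty[TaskRef]

      def register(obj: AnyRef, task: Task): Unit = refs.synchronized {
        refs += new TaskRef(obj, task, queue)
      }
    }

Registering explicit weak references is what allows this patch to drop the finalize() overrides in Dependency and RDD, whose removed comments warned against doing any time- or resource-intensive work from a finalizer.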
*/ - def scheduleRDDCleanup(rddId: Int) { - enqueue(CleanRDD(rddId)) - logDebug("Enqueued RDD " + rddId + " for cleaning up") + def registerRDDForCleanup(rdd: RDD[_]) { + registerForCleanup(rdd, CleanRDD(rdd.id)) } /** - * Schedule cleanup of shuffle data. Do not perform any time or resource intensive - * computation in this function as this is called from a finalize() function. + * Register a shuffle dependency for cleanup when it is garbage collected. */ - def scheduleShuffleCleanup(shuffleId: Int) { - enqueue(CleanShuffle(shuffleId)) - logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) + } + + /** Cleanup RDD. */ + def cleanupRDD(rdd: RDD[_]) { + doCleanupRDD(rdd.id) + } + + /** Cleanup shuffle. */ + def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + doCleanupShuffle(shuffleDependency.shuffleId) } /** Attach a listener object to get information of when objects are cleaned. */ @@ -91,24 +105,23 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { sc.persistentRdds.remove(rddId) } - /** - * Enqueue a cleaning task. Do not perform any time or resource intensive - * computation in this function as this is called from a finalize() function. - */ - private def enqueue(task: CleaningTask) { - queue.put(task) + /** Register an object for cleanup. */ + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task) } /** Keep cleaning RDDs and shuffle data */ private def keepCleaning() { while (!isStopped) { try { - val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS)) - taskOpt.foreach { task => - logDebug("Got cleaning task " + taskOpt.get) + val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[WeakReferenceWithCleanupTask]) + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get task match { - case CleanRDD(rddId) => doCleanRDD(rddId) - case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId) + case CleanRDD(rddId) => doCleanupRDD(rddId) + case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId) } } } catch { @@ -119,8 +132,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform RDD cleaning */ - private def doCleanRDD(rddId: Int) { + /** Perform RDD cleanup. */ + private def doCleanupRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) unpersistRDD(rddId, false) @@ -131,8 +144,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform shuffle cleaning */ - private def doCleanShuffle(shuffleId: Int) { + /** Perform shuffle cleanup. 
*/ + private def doCleanupShuffle(shuffleId: Int) { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -144,7 +157,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + private def mapOutputTrackerMaster = + sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private def blockManagerMaster = sc.env.blockManager.master diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 557d424d7a786..132468ebdb4f8 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -53,24 +53,7 @@ class ShuffleDependency[K, V]( val shuffleId: Int = rdd.context.newShuffleId() - override def finalize() { - try { - if (rdd != null) { - rdd.sparkContext.cleaner.scheduleShuffleCleanup(shuffleId) - } - } catch { - case t: Throwable => - // Paranoia - If logError throws error as well, report to stderr. - try { - logError("Error in finalize", t) - } catch { - case _ : Throwable => - System.err.println("Error in finalize (and could not write to logError): " + t) - } - } finally { - super.finalize() - } - } + rdd.sparkContext.cleaner.registerShuffleForCleanup(this) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a75bca42257d4..364156a8e0779 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -147,6 +147,7 @@ abstract class RDD[T: ClassTag]( } storageLevel = newLevel // Register the RDD with the SparkContext + sc.cleaner.registerRDDForCleanup(this) sc.persistentRdds(id) = this this } @@ -1102,21 +1103,4 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } - - override def finalize() { - try { - sc.cleaner.scheduleRDDCleanup(id) - } catch { - case t: Throwable => - // Paranoia - If logError throws error as well, report to stderr. 
- try { - logError("Error in finalize", t) - } catch { - case _ : Throwable => - System.err.println("Error in finalize (and could not write to logError): " + t) - } - } finally { - super.finalize() - } - } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 8556888c96e06..a5f17309b4ec5 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -9,7 +9,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{ShuffleCoGroupSplitDep, RDD} import scala.util.Random import java.lang.ref.WeakReference @@ -23,18 +23,28 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("cleanup RDD") { val rdd = newRDD.persist() - rdd.count() + val collected = rdd.collect().toList val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - cleaner.scheduleRDDCleanup(rdd.id) + + // Explicit cleanup + cleaner.cleanupRDD(rdd) tester.assertCleanup + + // verify that RDDs can be re-executed after cleaning up + assert(rdd.collect().toList === collected) } test("cleanup shuffle") { - val rdd = newShuffleRDD - rdd.count() - val tester = new CleanerTester(sc, shuffleIds = Seq(0)) - cleaner.scheduleShuffleCleanup(0) + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) + + // Explicit cleanup + shuffleDeps.foreach(s => cleaner.cleanupShuffle(s)) tester.assertCleanup + + // Verify that shuffles can be re-executed after cleaning up + assert(rdd.collect().toList === collected) } test("automatically cleanup RDD") { @@ -43,7 +53,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - doGC() + runGC() intercept[Exception] { preGCTester.assertCleanup(timeout(1000 millis)) } @@ -51,7 +61,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // test that GC causes RDD cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) rdd = null // make RDD out of scope - doGC() + runGC() postGCTester.assertCleanup } @@ -61,7 +71,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // test that GC does not cause shuffle cleanup due to a strong reference val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - doGC() + runGC() intercept[Exception] { preGCTester.assertCleanup(timeout(1000 millis)) } @@ -69,7 +79,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope - doGC() + runGC() postGCTester.assertCleanup } @@ -87,7 +97,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } val buffer = new ArrayBuffer[RDD[_]] - for (i <- 1 to 1000) { + for (i <- 1 to 500) { buffer += randomRDD } @@ -95,34 +105,47 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val shuffleIds = 0 until 
sc.newShuffleId val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + runGC() intercept[Exception] { preGCTester.assertCleanup(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) buffer.clear() - doGC() + runGC() postGCTester.assertCleanup } - // TODO (TD): Test that cleaned up RDD and shuffle can be recomputed again correctly. - def newRDD = sc.makeRDD(1 to 10) def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - def doGC() { + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { + rdd.dependencies ++ rdd.dependencies.flatMap { dep => + getAllDependencies(dep.rdd) + } + } + val rdd = newShuffleRDD + + // Get all the shuffle dependencies + val shuffleDeps = getAllDependencies(rdd).filter(_.isInstanceOf[ShuffleDependency[_, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _]]) + (rdd, shuffleDeps) + } + + /** Run GC and make sure it actually has run */ + def runGC() { val weakRef = new WeakReference(new Object()) val startTime = System.currentTimeMillis System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + // Wait until a weak reference object has been GCed while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { System.gc() System.runFinalization() - Thread.sleep(100) + Thread.sleep(200) } } @@ -149,10 +172,14 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In } } + val MAX_VALIDATION_ATTEMPTS = 10 + val VALIDATION_ATTEMPT_INTERVAL = 100 + logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) preCleanupValidate() sc.cleaner.attachListener(cleanerListener) + /** Assert that all the stuff has been cleaned up */ def assertCleanup(implicit waitTimeout: Eventually.Timeout) { try { eventually(waitTimeout, interval(10 millis)) { @@ -165,6 +192,7 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In } } + /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") @@ -181,14 +209,34 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") } + /** + * Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there is + * as there is not guarantee on how long it will take clean up the resources. 
+ */ private def postCleanupValidate() { - // Verify all the RDDs have been persisted - assert(rddIds.forall(!sc.persistentRdds.contains(_))) - assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) - - // Verify all the shuffle have been deregistered and cleaned up - assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) - assert(shuffleIds.forall(shuffleId => !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) + var attempts = 0 + while (attempts < MAX_VALIDATION_ATTEMPTS) { + attempts += 1 + logInfo("Attempt: " + attempts) + try { + // Verify all the RDDs have been unpersisted + assert(rddIds.forall(!sc.persistentRdds.contains(_))) + assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) + + // Verify all the shuffle have been deregistered and cleaned up + assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) + assert(shuffleIds.forall(shuffleId => + !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) + return + } catch { + case t: Throwable => + if (attempts >= MAX_VALIDATION_ATTEMPTS) { + throw t + } else { + Thread.sleep(VALIDATION_ATTEMPT_INTERVAL) + } + } + } } private def uncleanedResourcesToString = { From 620eca349808befa6c339bc5acc351c484495557 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Mar 2014 13:05:47 -0700 Subject: [PATCH 12/37] Changes based on PR comments. --- .../org/apache/spark/MapOutputTracker.scala | 4 ++-- .../storage/BlockManagerMasterActor.scala | 2 +- .../apache/spark/util/BoundedHashMap.scala | 4 ++-- .../spark/util/TimeStampedHashMap.scala | 4 ++-- .../util/TimeStampedWeakValueHashMap.scala | 19 ++++++++----------- .../spark/util/WrappedJavaHashMap.scala | 10 +++++++--- .../spark/util/WrappedJavaHashMapSuite.scala | 10 +++++----- 7 files changed, 27 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ffdf9115e1aae..ad9ee73e6b2e0 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -70,7 +70,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * their cache of map output locations if this happens. 
*/ protected var epoch: Long = 0 - protected val epochLock = new java.lang.Object + protected val epochLock = new AnyRef /** Remembers which map output locations are currently being fetched on a worker */ private val fetching = new HashSet[Int] @@ -305,7 +305,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.clear() } - protected def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 64cbedc8afcd3..cefbd28511bfd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -150,7 +150,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act private def removeShuffle(shuffleId: Int) { // Nothing to do in the BlockManagerMasterActor data structures val removeMsg = RemoveShuffle(shuffleId) - blockManagerInfo.values.map { bm => + blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala index c4f7df1ee0a7b..888a06b2408c9 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala @@ -45,14 +45,14 @@ import scala.reflect.ClassTag private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { - protected[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( + private[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( bound / 8, (0.75).toFloat, useLRU) { override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { size() > bound } }) - protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { new BoundedHashMap[K1, V1](bound, useLRU) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 60901c5e36130..c4d770fecdf74 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -39,9 +39,9 @@ private[util] case class TimeStampedValue[T](timestamp: Long, value: T) private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging { - protected[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() + private[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() - protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { new TimeStampedHashMap[K1, V1]() } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index bd86d78b8010f..09a6faf33ec60 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala 
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -51,11 +51,11 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]() /** Counter for counting the number of inserts */ private val insertCounts = new AtomicInteger(0) - protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { + private[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() } - protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { new TimeStampedWeakValueHashMap[K1, V1]() } @@ -68,15 +68,12 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]() } override def get(key: A): Option[B] = { - Option(internalJavaMap.get(key)) match { - case Some(weakValue) => - val value = weakValue.weakValue.get - if (value == null) { - internalJavaMap.remove(key) - } - Option(value) - case None => - None + Option(internalJavaMap.get(key)).flatMap { weakValue => + val value = weakValue.weakValue.get + if (value == null) { + internalJavaMap.remove(key) + } + Option(value) } } diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala index 59e35c3abf172..6cc3007f5d7ac 100644 --- a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala @@ -46,11 +46,15 @@ private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] /* Methods that must be defined. */ - /** Internal Java HashMap that is being wrapped. */ - protected[util] val internalJavaMap: JMap[IK, IV] + /** + * Internal Java HashMap that is being wrapped. + * Scoped private[util] so that rest of Spark code cannot + * directly access the internal map. + */ + private[util] val internalJavaMap: JMap[IK, IV] /** Method to get a new instance of the internal Java HashMap. */ - protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] + private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] /* Methods that convert between internal and external types. 
These implementations diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index 37c1f748a6f3d..e446c7f75dc0b 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.util -import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Random - import java.util import java.lang.ref.WeakReference +import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import scala.util.Random + import org.scalatest.FunSuite class WrappedJavaHashMapSuite extends FunSuite { @@ -203,9 +203,9 @@ class WrappedJavaHashMapSuite extends FunSuite { } class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { - protected[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() + private[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() - protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { new TestMap[K1, V1] } } From d2f8b977f2d78689512b67c82627f0b22e64daa7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Mar 2014 14:36:07 -0700 Subject: [PATCH 13/37] Removed duplicate unpersistRDD. --- .../src/main/scala/org/apache/spark/ContextCleaner.scala | 9 +-------- core/src/main/scala/org/apache/spark/SparkContext.scala | 3 +-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d499af20502d0..deabf6f5c8c5f 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -98,13 +98,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { listeners += listener } - /** Unpersists RDD and remove all blocks for it from memory and disk. */ - def unpersistRDD(rddId: Int, blocking: Boolean) { - logDebug("Unpersisted RDD " + rddId) - sc.env.blockManager.master.removeRdd(rddId, blocking) - sc.persistentRdds.remove(rddId) - } - /** Register an object for cleanup. 
*/ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task) @@ -136,7 +129,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def doCleanupRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) - unpersistRDD(rddId, false) + sc.unpersistRDD(rddId, false) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 500fb098e6649..5cd2caed10297 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -756,8 +756,7 @@ class SparkContext( /** * Unpersist an RDD from memory and/or disk storage */ - private[spark] def unpersistRDD(rdd: RDD[_], blocking: Boolean = true) { - val rddId = rdd.id + private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) { env.blockManager.master.removeRdd(rddId, blocking) persistentRdds.remove(rddId) listenerBus.post(SparkListenerUnpersistRDD(rddId)) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a5dc7a959fb22..2b7e3d99e68cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -158,7 +158,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this, blocking) + sc.unpersistRDD(this.id, blocking) storageLevel = StorageLevel.NONE this } From 6c9dcf608a0628a70b1ef48bde985e1e37f7bac4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Mar 2014 15:14:33 -0700 Subject: [PATCH 14/37] Added missing Apache license --- .../org/apache/spark/ContextCleanerSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index a5f17309b4ec5..b07f8817b7974 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} From ba52e00303896e46ce9cb5122e78e12d7cae7864 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 14:43:52 -0700 Subject: [PATCH 15/37] Refactor broadcast classes --- .../scala/org/apache/spark/SparkContext.scala | 7 +- .../apache/spark/broadcast/Broadcast.scala | 51 ----------- .../spark/broadcast/BroadcastFactory.scala | 2 +- .../spark/broadcast/BroadcastManager.scala | 63 ++++++++++++++ .../spark/broadcast/HttpBroadcast.scala | 59 +++---------- .../broadcast/HttpBroadcastFactory.scala | 34 ++++++++ .../spark/broadcast/TorrentBroadcast.scala | 86 ++----------------- .../broadcast/TorrentBroadcastFactory.scala | 36 ++++++++ .../apache/spark/storage/BlockManager.scala | 12 --- .../apache/spark/storage/MemoryStore.scala | 38 ++++---- .../org/apache/spark/BroadcastSuite.scala | 49 ----------- 11 files changed, 169 insertions(+), 268 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e4b40a7f7b4d..5cd2caed10297 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -641,13 +641,8 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. - * - * If `registerBlocks` is true, workers will notify driver about blocks they create - * and these blocks will be dropped when `unpersist` method of the broadcast variable is called. */ - def broadcast[T](value: T, registerBlocks: Boolean = false) = { - env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks) - } + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 516e6ba4005c8..e3e1e4f29b107 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -18,9 +18,6 @@ package org.apache.spark.broadcast import java.io.Serializable -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark._ /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable @@ -53,56 +50,8 @@ import org.apache.spark._ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T - /** - * Removes all blocks of this broadcast from memory (and disk if removeSource is true). - * - * @param removeSource Whether to remove data from disk as well. - * Will cause errors if broadcast is accessed on workers afterwards - * (e.g. in case of RDD re-computation due to executor failure). - */ - def unpersist(removeSource: Boolean = false) - // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. 
override def toString = "Broadcast(" + id + ")" } - -private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = conf.get( - "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf, securityManager) - - initialized = true - } - } - } - - def stop() { - broadcastFactory.stop() - } - - private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks) - - def isDriver = _isDriver -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 7aff8d7bb670b..0a0bb6cca336c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -28,6 +28,6 @@ import org.apache.spark.SparkConf */ trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala new file mode 100644 index 0000000000000..746e23e81931a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +private[spark] class BroadcastManager( + val isDriver: Boolean, + conf: SparkConf, + securityManager: SecurityManager) + extends Logging with Serializable { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver, conf, securityManager) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = { + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 6c2413cea526a..374180e472805 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def unpersist(removeSource: Boolean) { - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(blockId) - SparkEnv.get.blockManager.removeBlock(blockId) - } - - if (removeSource) { - HttpBroadcast.synchronized { - HttpBroadcast.cleanupById(id) - } - } - } - def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { @@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -/** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. 
- */ -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = - new HttpBroadcast[T](value_, isLocal, id, registerBlocks) - - def stop() { HttpBroadcast.stop() } -} - private object HttpBroadcast extends Logging { private var initialized = false @@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) - def write(id: Long, value: Any) { - val file = getFile(id) + val file = new File(broadcastDir, BroadcastBlockId(id).name) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging { obj } - def deleteFile(fileName: String) { - try { - new File(fileName).delete() - logInfo("Deleted broadcast file '" + fileName + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e) - } - } - def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - iterator.remove() - deleteFile(file) + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } } } } - - def cleanupById(id: Long) { - val file = getFile(id).getAbsolutePath - files.internalMap.remove(file) - deleteFile(file) - } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala new file mode 100644 index 0000000000000..c4f0f149534a5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. 
+ */ +class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 206765679e9ed..0828035c5d217 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,68 +26,12 @@ import org.apache.spark._ import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def unpersist(removeSource: Boolean) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(broadcastId) - SparkEnv.get.blockManager.removeBlock(broadcastId) - } - - if (!removeSource) { - //We can't tell BlockManager master to remove blocks from all nodes except driver, - //so we need to save them here in order to store them on disk later. - //This may be inefficient if blocks were already dropped to disk, - //but since unpersist is supposed to be called right after working with - //a broadcast this should not happen (and getting them from memory is cheap). - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - - for (pid <- 0 until totalBlocks) { - val pieceId = pieceBlockId(pid) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } - } - } - } - - for (pid <- 0 until totalBlocks) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid)) - } - } - - if (removeSource) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.removeBlock(metaId) - } - } else { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.dropFromMemory(metaId) - } - - for (i <- 0 until totalBlocks) { - val pieceId = pieceBlockId(i) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true) - } - } - arrayOfBlocks = null - } - } - def broadcastId = BroadcastBlockId(id) - private def metaId = BroadcastHelperBlockId(broadcastId, "meta") - private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid) - private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -110,6 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = tInfo.totalBlocks // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -118,7 +63,7 @@ private[spark] class 
TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = pieceBlockId(i) + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) @@ -148,7 +93,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() @@ -171,6 +116,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo def receiveBroadcast(variableID: Long): Boolean = { // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -193,9 +139,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } // Receive actual blocks - val recvOrder = new Random().shuffle(pieceIds) + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = pieceBlockId(pid) + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => @@ -215,8 +161,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } -private object TorrentBroadcast -extends Logging { +private object TorrentBroadcast extends Logging { private var initialized = false private var conf: SparkConf = null @@ -289,18 +234,3 @@ private[spark] case class TorrentInfo( @transient var hasBlocks = 0 } - -/** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. - */ -class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = - new TorrentBroadcast[T](value_, isLocal, id, registerBlocks) - - def stop() { TorrentBroadcast.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala new file mode 100644 index 0000000000000..a51c438c57717 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. + */ +class TorrentBroadcastFactory extends BroadcastFactory { + + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 84c87949adae4..ca23513c4dc64 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,11 +209,6 @@ private[spark] class BlockManager( } } - /** - * For testing. Returns number of blocks BlockManager knows about that are in memory. - */ - def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_)) - /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -817,13 +812,6 @@ private[spark] class BlockManager( } /** - * Drop a block from memory, possibly putting it on disk if applicable. - */ - def dropFromMemory(blockId: BlockId) { - memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId) - } - - /** * Remove all blocks belonging to the given RDD. * @return The number of blocks removed. */ diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 7d614aa4726b2..488f1ea9628f5 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -210,27 +210,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Drop a block from memory, possibly putting it on disk if applicable. - */ - def dropFromMemory(blockId: BlockId) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null if called from ensureFreeSpace as only one - // thread should be dropping blocks and removing entries. - // However the check is required in other cases. 
- if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } - } - - /** - * Tries to free up a given amount of space to store a particular block, but can fail and return - * false if either the block is bigger than our memory or it would require replacing another - * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that + * Try to free up a given amount of space to store a particular block, but can fail if + * either the block is bigger than our memory or it would require replacing another block + * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * * Assume that a lock is held by the caller to ensure only one thread is dropping blocks. @@ -272,7 +254,19 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - dropFromMemory(blockId) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } + } } return ResultWithDroppedBlocks(success = true, droppedBlocks) } else { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index dad330d6513da..e022accee6d08 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -18,15 +18,9 @@ package org.apache.spark import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Millis, Span} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ -import org.scalatest.matchers.ShouldMatchers._ class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") @@ -88,47 +82,4 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } - def blocksExist(sc: SparkContext, numSlaves: Int) = { - val rdd = sc.parallelize(1 to numSlaves, numSlaves) - val workerBlocks = rdd.mapPartitions(_ => { - val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory() - Seq(blocks).iterator - }) - val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory() - - totalKnown > 0 - } - - def testUnpersist(bcFactory: String, removeSource: Boolean) { - test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) { - val numSlaves = 2 - System.setProperty("spark.broadcast.factory", bcFactory) - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - - assert(!blocksExist(sc, numSlaves)) - - val listBroadcast = sc.broadcast(list, true) - val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - - assert(blocksExist(sc, numSlaves)) - - listBroadcast.unpersist(removeSource) - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - blocksExist(sc, numSlaves) should be (false) - } - - if (!removeSource) { - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - } - } - - for (removeSource <- Seq(true, false)) { - testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource) - testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource) - } } From d0edef3dda333b5bf43a320acd214f276b8a5b3e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 14:57:07 -0700 Subject: [PATCH 16/37] Add framework for broadcast cleanup As of this commit, Spark does not clean up broadcast blocks. This will be done in the next commit. --- .../org/apache/spark/ContextCleaner.scala | 134 +++++++++++------- .../scala/org/apache/spark/SparkContext.scala | 6 +- .../apache/spark/broadcast/Broadcast.scala | 6 + .../spark/broadcast/BroadcastFactory.scala | 1 + .../spark/broadcast/BroadcastManager.scala | 4 + .../spark/broadcast/HttpBroadcast.scala | 81 ++++++++--- .../broadcast/HttpBroadcastFactory.scala | 8 ++ .../spark/broadcast/TorrentBroadcast.scala | 86 ++++++----- .../broadcast/TorrentBroadcastFactory.scala | 7 + .../spark/storage/BlockManagerMessages.scala | 2 +- .../storage/BlockManagerSlaveActor.scala | 5 +- .../apache/spark/ContextCleanerSuite.scala | 21 ++- 12 files changed, 249 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index deabf6f5c8c5f..f856a13f84dec 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -21,27 +21,41 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -/** Listener class used for testing when any item has been cleaned by the Cleaner class */ -private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) -} +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanRDD(rddId: Int) extends CleanupTask +private case class CleanShuffle(shuffleId: Int) extends CleanupTask +private case class CleanBroadcast(broadcastId: Long) extends CleanupTask /** - * Cleans RDDs and shuffle data. + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for RDD, shuffle, and broadcast state. + * + * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest, + * to be processed when the associated object goes out of scope of the application. Actual + * cleanup is performed in a separate daemon thread. 
*/ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - /** Classes to represent cleaning tasks */ - private sealed trait CleanupTask - private case class CleanRDD(rddId: Int) extends CleanupTask - private case class CleanShuffle(shuffleId: Int) extends CleanupTask - // TODO: add CleanBroadcast + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] - private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask] - with SynchronizedBuffer[WeakReferenceWithCleanupTask] private val referenceQueue = new ReferenceQueue[AnyRef] private val listeners = new ArrayBuffer[CleanerListener] @@ -49,77 +63,64 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - private val REF_QUEUE_POLL_TIMEOUT = 100 - @volatile private var stopped = false - private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask) - extends WeakReference(referent, referenceQueue) + /** Attach a listener object to get information of when objects are cleaned. */ + def attachListener(listener: CleanerListener) { + listeners += listener + } - /** Start the cleaner */ + /** Start the cleaner. */ def start() { cleaningThread.setDaemon(true) cleaningThread.setName("ContextCleaner") cleaningThread.start() } - /** Stop the cleaner */ + /** Stop the cleaner. */ def stop() { stopped = true cleaningThread.interrupt() } - /** - * Register a RDD for cleanup when it is garbage collected. - */ + /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]) { registerForCleanup(rdd, CleanRDD(rdd.id)) } - /** - * Register a shuffle dependency for cleanup when it is garbage collected. - */ + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } - /** Cleanup RDD. */ - def cleanupRDD(rdd: RDD[_]) { - doCleanupRDD(rdd.id) - } - - /** Cleanup shuffle. */ - def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { - doCleanupShuffle(shuffleDependency.shuffleId) - } - - /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { - listeners += listener + /** Register a Broadcast for cleanup when it is garbage collected. */ + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { - referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task) + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } - /** Keep cleaning RDDs and shuffle data */ + /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ private def keepCleaning() { - while (!isStopped) { + while (!stopped) { try { - val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT)) - .map(_.asInstanceOf[WeakReferenceWithCleanupTask]) + val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) referenceBuffer -= reference.get task match { case CleanRDD(rddId) => doCleanupRDD(rddId) case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId) + case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId) } } } catch { case ie: InterruptedException => - if (!isStopped) logWarning("Cleaning thread interrupted") + if (!stopped) logWarning("Cleaning thread interrupted") case t: Throwable => logError("Error in cleaning thread", t) } } @@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def doCleanupRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) - sc.unpersistRDD(rddId, false) + sc.unpersistRDD(rddId, blocking = false) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { @@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def mapOutputTrackerMaster = - sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + /** Perform broadcast cleanup. */ + private def doCleanupBroadcast(broadcastId: Long) { + try { + logDebug("Cleaning broadcast " + broadcastId) + broadcastManager.unbroadcast(broadcastId, removeFromDriver = true) + listeners.foreach(_.broadcastCleaned(broadcastId)) + logInfo("Cleaned broadcast " + broadcastId) + } catch { + case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t) + } + } private def blockManagerMaster = sc.env.blockManager.master + private def broadcastManager = sc.env.broadcastManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + // Used for testing + + private[spark] def cleanupRDD(rdd: RDD[_]) { + doCleanupRDD(rdd.id) + } + + private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + doCleanupShuffle(shuffleDependency.shuffleId) + } - private def isStopped = stopped + private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) { + doCleanupBroadcast(broadcast.id) + } + +} + +private object ContextCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} + +/** + * Listener class used for testing when any item has been cleaned by the Cleaner class. + */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) + def broadcastCleaned(broadcastId: Long) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5cd2caed10297..689180fcd719b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -642,7 +642,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T) = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.registerBroadcastForCleanup(bc) + bc + } /** * Add a file to be downloaded with this Spark job on every node. 
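The driver-side flow at this point in the series is: SparkContext.broadcast creates the Broadcast, registers it with ContextCleaner.registerBroadcastForCleanup, and the cleaner later turns the garbage-collected reference into a CleanBroadcast task handled by unbroadcast. The weak-reference plus reference-queue pattern that drives this is easy to get subtly wrong, so the following is a minimal, self-contained sketch of that pattern only. All names below are illustrative and not part of the patch, and GC timing is never guaranteed, so the demo merely requests a collection.

import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.ArrayBuffer

// Stand-in for CleanRDD / CleanShuffle / CleanBroadcast.
sealed trait Task
case class CleanThing(id: Long) extends Task

// Pairs a cleanup task with a weak reference to the object it belongs to.
class TaskWeakRef(val task: Task, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

object WeakRefCleanerSketch {
  private val queue = new ReferenceQueue[AnyRef]
  // Strong references to the wrappers themselves; without these the wrappers
  // could be collected before they are ever enqueued.
  private val buffer = new ArrayBuffer[TaskWeakRef]

  def register(obj: AnyRef, task: Task): Unit =
    buffer.synchronized { buffer += new TaskWeakRef(task, obj, queue) }

  def main(args: Array[String]): Unit = {
    var payload: AnyRef = new Array[Byte](1 << 20)
    register(payload, CleanThing(42L))
    payload = null                     // drop the last strong reference
    System.gc()                        // only a request; usually enough here

    // Poll with a timeout, like the daemon cleaning thread in the patch.
    Option(queue.remove(1000)).map(_.asInstanceOf[TaskWeakRef]).foreach { ref =>
      buffer.synchronized { buffer -= ref }
      println("Running cleanup for " + ref.task)   // e.g. unbroadcast(id)
    }
  }
}

Holding the wrappers in a buffer mirrors referenceBuffer in ContextCleaner: it is what keeps each CleanupTaskWeakReference reachable until its referent is enqueued, after which the cleaning thread removes it and runs the associated task.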
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3e1e4f29b107..d75b9acfb7aa0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -50,6 +50,12 @@ import java.io.Serializable abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Remove all persisted state associated with this broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unpersist(removeFromDriver: Boolean) + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 0a0bb6cca336c..850650951e603 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -29,5 +29,6 @@ import org.apache.spark.SparkConf trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean) def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 746e23e81931a..85d62aae03959 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -60,4 +60,8 @@ private[spark] class BroadcastManager( broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } + def unbroadcast(id: Long, removeFromDriver: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver) + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 374180e472805..89361efec44a4 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -21,10 +21,9 @@ import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -32,18 +31,27 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { - 
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } if (!isLocal) { HttpBroadcast.write(id, value_) } + /** + * Remove all persisted state associated with this HTTP broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -54,7 +62,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,7 +72,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null @@ -74,7 +83,7 @@ private object HttpBroadcast extends Logging { private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[String] + val files = new TimeStampedHashSet[String] private var cleaner: MetadataCleaner = null private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt @@ -122,8 +131,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -146,7 +157,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -155,7 +166,7 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { @@ -169,20 +180,50 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. 
+ */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + if (removeFromDriver) { + val file = new File(broadcastDir, BroadcastBlockId(id).name) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteBroadcastFile(new File(file.toString)) } } } + + /** Delete the given broadcast file. */ + private def deleteBroadcastFile(file: File) { + try { + if (!file.exists()) { + logWarning("Broadcast file to be deleted does not exist: %s".format(file)) + } else if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) + } + } catch { + case e: Exception => + logWarning("Exception while deleting broadcast file: %s".format(file), e) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c4f0f149534a5..4affa922156c9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,4 +31,12 @@ class HttpBroadcastFactory extends BroadcastFactory { new HttpBroadcast[T](value_, isLocal, id) def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. 
+ */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 0828035c5d217..07ef54bb120b9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,12 +29,13 @@ import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -47,8 +48,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) - + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks @@ -58,7 +58,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces @@ -66,11 +66,19 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } + /** + * Remove all persisted state associated with this HTTP broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -86,18 +94,18 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. 
*/ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,7 +122,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { + def receiveBroadcast(): Boolean = { // Receive meta-info val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 @@ -148,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,15 +164,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast extends Logging { - +private[spark] object TorrentBroadcast extends Logging { private var initialized = false private var conf: SparkConf = null + + lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests synchronized { @@ -178,39 +188,37 @@ private object TorrentBroadcast extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -219,6 +227,14 @@ private object TorrentBroadcast extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. 
+ * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + } + } private[spark] case class TorrentBlock( @@ -227,7 +243,7 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index a51c438c57717..eabe792b550bb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -33,4 +33,11 @@ class TorrentBroadcastFactory extends BroadcastFactory { def stop() { TorrentBroadcast.stop() } + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 50ea4e31ce509..4c5b31d0abe44 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -35,7 +35,7 @@ private[storage] object BlockManagerMessages { case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave // Remove all blocks belonging to a specific shuffle. 
- case class RemoveShuffle(shuffleId: Int) + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave ////////////////////////////////////////////////////////////////////////////////// diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index a6ff147c1d3e6..9a12481b7f6d5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -29,8 +29,9 @@ import org.apache.spark.storage.BlockManagerMessages._ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker - ) extends Actor { + mapOutputTracker: MapOutputTracker) + extends Actor { + override def receive = { case RemoveBlock(blockId) => diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index b07f8817b7974..11e22145ebb88 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark +import java.lang.ref.WeakReference + import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually @@ -26,9 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} -import org.apache.spark.rdd.{ShuffleCoGroupSplitDep, RDD} -import scala.util.Random -import java.lang.ref.WeakReference +import org.apache.spark.rdd.RDD class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -67,7 +68,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - + // test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() @@ -171,11 +172,16 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo /** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. 
*/ -class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { @@ -187,6 +193,11 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In toBeCleanedShuffleIds -= shuffleId logInfo("Shuffle " + shuffleId + " cleaned") } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcstIds -= broadcastId + logInfo("Broadcast" + broadcastId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 From 544ac866edf21230140fe56ee7a428fe0ab86329 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 15:11:42 -0700 Subject: [PATCH 17/37] Clean up broadcast blocks through BlockManager* --- .../apache/spark/broadcast/HttpBroadcast.scala | 2 +- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 14 +++++++++++++- .../spark/storage/BlockManagerMaster.scala | 7 +++++++ .../spark/storage/BlockManagerMasterActor.scala | 16 +++++++++++++--- .../spark/storage/BlockManagerMessages.scala | 13 ++++++++++--- .../spark/storage/BlockManagerSlaveActor.scala | 3 +++ .../main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- .../org/apache/spark/ContextCleanerSuite.scala | 2 +- 9 files changed, 53 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 89361efec44a4..4985d4202ed6b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -186,7 +186,7 @@ private[spark] object HttpBroadcast extends Logging { * and delete the associated broadcast file. */ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) if (removeFromDriver) { val file = new File(broadcastDir, BroadcastBlockId(id).name) files.remove(file.toString) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 07ef54bb120b9..51f1592cef752 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -232,7 +232,7 @@ private[spark] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. 
*/ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ca23513c4dc64..3c0941e195724 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -820,10 +820,22 @@ private[spark] class BlockManager( // from RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long) { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { + case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid + case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId) } + } + /** * Remove a block from both memory and disk. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ff3f22b3b092a..4579c0d959553 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -126,6 +126,13 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveShuffle(shuffleId)) } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { + askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 646ccb7fa74f6..4cc4227fd87e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -100,6 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus removeShuffle(shuffleId) sender ! true + case RemoveBroadcast(broadcastId, removeFromDriver) => + removeBroadcast(broadcastId, removeFromDriver) + sender ! true + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -151,9 +155,15 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def removeShuffle(shuffleId: Int) { // Nothing to do in the BlockManagerMasterActor data structures val removeMsg = RemoveShuffle(shuffleId) - blockManagerInfo.values.foreach { bm => - bm.slaveActor ! removeMsg - } + blockManagerInfo.values.foreach { bm => bm.slaveActor ! 
removeMsg } + } + + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + // TODO(aor): Consolidate usages of + val removeMsg = RemoveBroadcast(broadcastId) + blockManagerInfo.values + .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } + .foreach { bm => bm.slaveActor ! removeMsg } } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 4c5b31d0abe44..3ea710ebc786e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -22,9 +22,11 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef private[storage] object BlockManagerMessages { + ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerSlave // Remove a block from the slaves that have it. This can only be used to remove @@ -37,10 +39,15 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific shuffle. case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerMaster case class RegisterBlockManager( @@ -57,8 +64,7 @@ private[storage] object BlockManagerMessages { var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { + extends ToBlockManagerMaster with Externalizable { def this() = this(null, null, null, 0, 0) // For deserialization only @@ -80,7 +86,8 @@ private[storage] object BlockManagerMessages { } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 9a12481b7f6d5..8c2ccbe6a7e66 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -46,5 +46,8 @@ class BlockManagerSlaveActor( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } + + case RemoveBroadcast(broadcastId, _) => + blockManager.removeBroadcast(broadcastId) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ad87fda140476..e541591ee7582 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -461,10 +461,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. 
- val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 11e22145ebb88..77d9825434706 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,8 +28,8 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { From e95479cd63b3259beddea278befd0bdee89bb17e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 14:37:51 -0700 Subject: [PATCH 18/37] Add tests for unpersisting broadcast There is not currently a way to query the blocks on the executors, an operation that is deceptively simple to accomplish. This commit adds this mechanism in order to verify that blocks are in fact persisted/unpersisted on the executors in the tests. --- .../apache/spark/broadcast/Broadcast.scala | 16 +- .../spark/broadcast/HttpBroadcast.scala | 13 +- .../spark/broadcast/TorrentBroadcast.scala | 13 +- .../apache/spark/storage/BlockManager.scala | 20 +- .../spark/storage/BlockManagerMaster.scala | 18 ++ .../storage/BlockManagerMasterActor.scala | 24 +- .../spark/storage/BlockManagerMessages.scala | 7 + .../storage/BlockManagerSlaveActor.scala | 7 +- .../org/apache/spark/BroadcastSuite.scala | 254 +++++++++++++++--- 9 files changed, 309 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d75b9acfb7aa0..3a2fef05861e6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -48,16 +48,26 @@ import java.io.Serializable * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { + + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + protected var isValid: Boolean = true + def value: T /** - * Remove all persisted state associated with this broadcast. + * Remove all persisted state associated with this broadcast. Overriding implementations + * should set isValid to false if persisted state is also removed from the driver. + * * @param removeFromDriver Whether to remove state from the driver. + * If true, the resulting broadcast should no longer be valid. */ def unpersist(removeFromDriver: Boolean) - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + // We cannot define abstract readObject and writeObject here due to some weird issues + // with these methods having to be 'private' in sub-classes. 
override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4985d4202ed6b..d5e3d60a5b2b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,8 +17,8 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} @@ -49,10 +49,17 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver HttpBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 51f1592cef752..ace71575f5390 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,12 +17,12 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -76,10 +76,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo * @param removeFromDriver Whether to remove state from the driver. 
*/ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver TorrentBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3c0941e195724..78dc32b4b1525 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -58,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -210,9 +210,9 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get storage level of local block. If no info exists for the block, return None. */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) /** * Tell the master about the current storage status of a block. This will send a block update @@ -496,9 +496,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -816,8 +815,7 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } @@ -827,13 +825,13 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. 
*/ - def removeBroadcast(broadcastId: Long) { + def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } - blocksToRemove.foreach { blockId => removeBlock(blockId) } + blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4579c0d959553..674322e3034c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -147,6 +147,24 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Mainly for testing. Ask the driver to query all executors for their storage levels + * regarding this block. This provides an avenue for the driver to learn the storage + * levels of blocks it has not been informed of. + * + * WARNING: This could lead to deadlocks if there are any outstanding messages the + * executors are already expecting from the driver. In this case, while the driver is + * waiting for the executors to respond to its GetStorageLevel query, the executors + * are also waiting for a response from the driver to a prior message. + * + * The interim solution is to wait for a brief window of time to pass before asking. + * This should suffice, since this mechanism is largely introduced for testing only. + */ + def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { + Thread.sleep(waitTimeMs) + askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 4cc4227fd87e2..f83c26dafe2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -126,6 +126,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) + case AskForStorageLevels(blockId) => + sender ! askForStorageLevels(blockId) + case other => logWarning("Got unknown message: " + other) } @@ -158,6 +161,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } } + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. 
+ */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { // TODO(aor): Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) @@ -246,6 +254,19 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + // For testing. Ask all block managers for the given block's local storage level, if any. + private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { + val getStorageLevel = GetStorageLevel(blockId) + blockManagerInfo.values.flatMap { info => + val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) + val result = Await.result(future, akkaTimeout) + if (result != null) { + // If the block does not exist on the slave, the slave replies None + result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } + } else None + }.toMap + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -329,6 +350,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3ea710ebc786e..1d3e94c4b6533 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,6 +43,9 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + // For testing. Ask the slave for the block's storage level. + case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -116,4 +119,8 @@ private[storage] object BlockManagerMessages { case object ExpireDeadHosts extends ToBlockManagerMaster case object GetStorageStatus extends ToBlockManagerMaster + + // For testing. Have the master ask all slaves for the given block's storage level. + case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8c2ccbe6a7e66..85b8ec40c0ea3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -47,7 +47,10 @@ class BlockManagerSlaveActor( mapOutputTracker.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, _) => - blockManager.removeBroadcast(broadcastId) + case RemoveBroadcast(broadcastId, removeFromDriver) => + blockManager.removeBroadcast(broadcastId, removeFromDriver) + + case GetStorageLevel(blockId) => + sender ! 
blockManager.getLevel(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a462654197ea0 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,67 +19,241 @@ package org.apache.spark import org.scalatest.FunSuite +import org.apache.spark.storage._ +import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} + class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = 
sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only") { + testUnpersistHttpBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver") { + testUnpersistHttpBroadcast(2, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only") { + testUnpersistTorrentBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver") { + testUnpersistTorrentBroadcast(2, removeFromDriver = true) + } + + /** + * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. 
+ */ + private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === (if (removeFromDriver) 0 else 1)) + assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. 
+ */ + private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") + Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + blockId match { + case BroadcastHelperBlockId(_, "meta") => + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + case _ => + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + var waitTimeMs = 1000L + blockIds.foreach { blockId => + // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks + val levels = bmm.askForStorageLevels(blockId, waitTimeMs) + assert(levels.size === expectedNumBlocks) + waitTimeMs = 0L + } + } + + testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. 
+ */ + private def testUnpersistBroadcast( + numSlaves: Int, + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BlockId], + afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + broadcast.unpersist(removeFromDriver) + afterUnpersist(blocks, blockManagerMaster) + + if (!removeFromDriver) { + // The broadcast variable is not completely destroyed (i.e. state still exists on driver) + // Using the variable again should yield the same answer as before. + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf } } From f201a8d3c2f3c95da986760ac7ce4acb199f4e71 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 15:39:51 -0700 Subject: [PATCH 19/37] Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap --- .../apache/spark/util/BoundedHashMap.scala | 67 -------- .../apache/spark/ContextCleanerSuite.scala | 147 +++++++++++------- .../spark/util/WrappedJavaHashMapSuite.scala | 5 - 3 files changed, 94 insertions(+), 125 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala deleted file mode 100644 index 888a06b2408c9..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} - -import java.util.{Collections, LinkedHashMap} -import java.util.Map.{Entry => JMapEntry} -import scala.reflect.ClassTag - -/** - * A map that upper bounds the number of key-value pairs present in it. It can be configured to - * drop the least recently user pair or the earliest inserted pair. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. - * - * Internally, a Java LinkedHashMap is used to get insert-order or access-order behavior. - * Note that the LinkedHashMap is not thread-safe and hence, it is wrapped in a - * Collections.synchronizedMap. However, getting the Java HashMap's iterator and - * using it can still lead to ConcurrentModificationExceptions. Hence, the iterator() - * function is overridden to copy the all pairs into an ArrayBuffer and then return the - * iterator to the ArrayBuffer. Also, the class apply the trait SynchronizedMap which - * ensures that all calls to the Scala Map API are synchronized. This together ensures - * that ConcurrentModificationException is never thrown. - * - * @param bound max number of key-value pairs - * @param useLRU true = least recently used/accessed will be dropped when bound is reached, - * false = earliest inserted will be dropped - */ -private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) - extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { - - private[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( - bound / 8, (0.75).toFloat, useLRU) { - override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { - size() > bound - } - }) - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new BoundedHashMap[K1, V1](bound, useLRU) - } - - /** - * Overriding iterator to make sure that the internal Java HashMap's iterator - * is not concurrently modified. This can be a performance issue and this should be overridden - * if it is known that this map will not be used in a multi-threaded environment. 
- */ - override def iterator: Iterator[(A, B)] = { - (new ArrayBuffer[(A, B)] ++= super.iterator).iterator - } -} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 77d9825434706..6a12cb6603700 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.lang.ref.WeakReference -import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} @@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} +import org.apache.spark.storage.{BroadcastBlockId, RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -46,9 +46,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup cleaner.cleanupRDD(rdd) - tester.assertCleanup + tester.assertCleanup() - // verify that RDDs can be re-executed after cleaning up + // Verify that RDDs can be re-executed after cleaning up assert(rdd.collect().toList === collected) } @@ -59,87 +59,101 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup shuffleDeps.foreach(s => cleaner.cleanupShuffle(s)) - tester.assertCleanup + tester.assertCleanup() // Verify that shuffles can be re-executed after cleaning up assert(rdd.collect().toList === collected) } + test("cleanup broadcast") { + val broadcast = newBroadcast + val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + + // Explicit cleanup + cleaner.cleanupBroadcast(broadcast) + tester.assertCleanup() + } + test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - // test that GC does not cause RDD cleanup due to a strong reference + // Test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes RDD cleanup after dereferencing the RDD + // Test that GC causes RDD cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - rdd = null // make RDD out of scope + rdd = null // Make RDD out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } test("automatically cleanup shuffle") { var rdd = newShuffleRDD rdd.count() - // test that GC does not cause shuffle cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD + // Test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope + rdd = null // Make RDD out of scope, so 
that corresponding shuffle goes out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } - test("automatically cleanup RDD + shuffle") { + test("automatically cleanup broadcast") { + var broadcast = newBroadcast - def randomRDD: RDD[_] = { - val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) - } - if (Random.nextBoolean()) rdd.persist() - rdd.count() - rdd + // Test that GC does not cause broadcast cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) } - val buffer = new ArrayBuffer[RDD[_]] - for (i <- 1 to 500) { - buffer += randomRDD - } + // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + broadcast = null // Make broadcast variable out of scope + runGC() + postGCTester.assertCleanup() + } + test("automatically cleanup RDD + shuffle + broadcast") { + val numRdds = 100 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) - buffer.clear() + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - + def newBroadcast = sc.broadcast(1 to 100) def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => @@ -149,11 +163,27 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newShuffleRDD // Get all the shuffle dependencies - val shuffleDeps = getAllDependencies(rdd).filter(_.isInstanceOf[ShuffleDependency[_, _]]) + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _]]) .map(_.asInstanceOf[ShuffleDependency[_, _]]) (rdd, shuffleDeps) } + def randomRdd = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + def randomBroadcast = { + sc.broadcast(Random.nextInt(Int.MaxValue)) + } + /** Run GC and make sure it actually has run */ def runGC() { val weakRef = new WeakReference(new Object()) @@ -208,7 +238,7 @@ class CleanerTester( sc.cleaner.attachListener(cleanerListener) /** Assert that all the stuff has been cleaned up */ - def assertCleanup(implicit waitTimeout: Eventually.Timeout) { + 
def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { try { eventually(waitTimeout, interval(10 millis)) { assert(isAllCleanedUp) @@ -222,7 +252,7 @@ class CleanerTester( /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present assert(rddIds.forall(sc.persistentRdds.contains), @@ -233,8 +263,12 @@ class CleanerTester( // Verify the shuffle ids are registered and blocks are present assert(shuffleIds.forall(mapOutputTrackerMaster.containsShuffle), "One or more shuffles have not been registered cannot start cleaner test") - assert(shuffleIds.forall(shuffleId => diskBlockManager.containsBlock(shuffleBlockId(shuffleId))), + assert(shuffleIds.forall(sid => diskBlockManager.containsBlock(shuffleBlockId(sid))), "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") + + // Verify that the broadcast is in the driver's block manager + assert(broadcastIds.forall(bid => blockManager.getLevel(broadcastBlockId(bid)).isDefined), + "One ore more broadcasts have not been persisted in the driver's block manager") } /** @@ -247,14 +281,19 @@ class CleanerTester( attempts += 1 logInfo("Attempt: " + attempts) try { - // Verify all the RDDs have been unpersisted + // Verify all RDDs have been unpersisted assert(rddIds.forall(!sc.persistentRdds.contains(_))) assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) - // Verify all the shuffle have been deregistered and cleaned up + // Verify all shuffles have been deregistered and cleaned up assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) - assert(shuffleIds.forall(shuffleId => - !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) + assert(shuffleIds.forall(sid => !diskBlockManager.containsBlock(shuffleBlockId(sid)))) + + // Verify all broadcasts have been unpersisted + assert(broadcastIds.forall { bid => + blockManager.master.askForStorageLevels(broadcastBlockId(bid)).isEmpty + }) + return } catch { case t: Throwable => @@ -271,18 +310,20 @@ class CleanerTester( s""" |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.mkString("[", ", ", "]")} """.stripMargin } - private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty - - private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def isAllCleanedUp = + toBeCleanedRDDIds.isEmpty && + toBeCleanedShuffleIds.isEmpty && + toBeCleanedBroadcstIds.isEmpty private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) + private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def broadcastBlockId(broadcastId: Long) = BroadcastBlockId(broadcastId) private def blockManager = sc.env.blockManager - private def diskBlockManager = blockManager.diskBlockManager - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index e446c7f75dc0b..0b9847174ac84 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -33,11 +33,6 @@ class WrappedJavaHashMapSuite extends FunSuite { // Test a simple WrappedJavaHashMap testMap(new TestMap[String, String]()) - // Test BoundedHashMap - testMap(new BoundedHashMap[String, String](100, true)) - - testMapThreadSafety(new BoundedHashMap[String, String](100, true)) - // Test TimeStampedHashMap testMap(new TimeStampedHashMap[String, String]) From 0d170606469ad1d58f7743f9cd57247d45082fad Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 19:07:55 -0700 Subject: [PATCH 20/37] Import, comments, and style fixes (minor) --- .../scala/org/apache/spark/MapOutputTracker.scala | 11 +++++------ .../scala/org/apache/spark/SparkContext.scala | 3 ++- .../main/scala/org/apache/spark/SparkEnv.scala | 1 + .../apache/spark/broadcast/BroadcastFactory.scala | 4 ++-- .../apache/spark/broadcast/HttpBroadcast.scala | 9 ++++----- .../apache/spark/broadcast/TorrentBroadcast.scala | 7 +++---- .../src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- .../org/apache/spark/scheduler/DAGScheduler.scala | 1 - .../org/apache/spark/storage/BlockManager.scala | 4 ++-- .../apache/spark/storage/BlockManagerMaster.scala | 2 +- .../spark/storage/BlockManagerMasterActor.scala | 3 +-- .../spark/storage/BlockManagerMessages.scala | 4 ---- .../apache/spark/storage/DiskBlockManager.scala | 2 +- .../spark/storage/ShuffleBlockManager.scala | 4 ++-- .../org/apache/spark/storage/ThreadingTest.scala | 2 +- .../org/apache/spark/util/MetadataCleaner.scala | 15 ++++++++------- .../spark/util/TimeStampedWeakValueHashMap.scala | 8 ++++---- .../org/apache/spark/MapOutputTrackerSuite.scala | 2 +- .../apache/spark/storage/BlockManagerSuite.scala | 3 +-- .../spark/util/WrappedJavaHashMapSuite.scala | 2 +- 20 files changed, 42 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e1a273593cce5..c45c5c90048f3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -112,8 +112,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and - * output sizes of the map outputs of a given shuffle + * Called from executors to get the server URIs and output sizes of the map outputs of + * a given shuffle. */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull @@ -218,10 +218,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) private var cacheEpoch = epoch /** - * Timestamp based HashMap for storing mapStatuses and cached serialized statuses - * in the master, so that statuses are dropped only by explicit deregistering or - * by TTL-based cleaning (if set). Other than these two - * scenarios, nothing should be dropped from this HashMap. + * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master, + * so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set). + * Other than these two scenarios, nothing should be dropped from this HashMap. 
*/ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fe84b812ba8d0..79574c271cfb6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -230,6 +230,7 @@ class SparkContext( private[spark] val cleaner = new ContextCleaner(this) cleaner.start() + postEnvironmentUpdate() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -773,7 +774,7 @@ class SparkContext( * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. */ def addJar(path: String) { - if (path == null) { + if (path == null) { logWarning("null specified as parameter to addJar") } else { var key = "" diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 521182021dd4b..62398dc930993 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -185,6 +185,7 @@ object SparkEnv extends Logging { } else { new MapOutputTrackerWorker(conf) } + // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself mapOutputTracker.trackerActor = registerOrLookup( diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 850650951e603..9ff1675e76a5e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,8 +27,8 @@ import org.apache.spark.SparkConf * entire Spark job. 
*/ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def unbroadcast(id: Long, removeFromDriver: Boolean) - def stop(): Unit + def stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index d5e3d60a5b2b7..d8981bb42e684 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -90,7 +90,7 @@ private[spark] object HttpBroadcast extends Logging { private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - val files = new TimeStampedHashSet[String] + private val files = new TimeStampedHashSet[String] private var cleaner: MetadataCleaner = null private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt @@ -195,7 +195,7 @@ private[spark] object HttpBroadcast extends Logging { def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) if (removeFromDriver) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) files.remove(file.toString) deleteBroadcastFile(file) } @@ -217,10 +217,9 @@ private[spark] object HttpBroadcast extends Logging { } } - /** Delete the given broadcast file. */ private def deleteBroadcastFile(file: File) { try { - if (!file.exists()) { + if (!file.exists) { logWarning("Broadcast file to be deleted does not exist: %s".format(file)) } else if (file.delete()) { logInfo("Deleted broadcast file: %s".format(file)) @@ -229,7 +228,7 @@ private[spark] object HttpBroadcast extends Logging { } } catch { case e: Exception => - logWarning("Exception while deleting broadcast file: %s".format(file), e) + logError("Exception while deleting broadcast file: %s".format(file), e) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index ace71575f5390..ab280fad4e28f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -72,7 +72,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } /** - * Remove all persisted state associated with this HTTP broadcast. + * Remove all persisted state associated with this Torrent broadcast. * @param removeFromDriver Whether to remove state from the driver. 
*/ override def unpersist(removeFromDriver: Boolean) { @@ -177,13 +177,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } private[spark] object TorrentBroadcast extends Logging { + private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def initialize(_isDriver: Boolean, conf: SparkConf) { - TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests + TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { initialized = true diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e5638d0132e88..e8d36e6bfc810 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -158,7 +158,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this.id, blocking) + sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } @@ -1128,4 +1128,5 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index edef40e7309f6..f31f0580c36fe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1090,7 +1090,6 @@ class DAGScheduler( eventProcessActor ! StopDAGScheduler } taskScheduler.stop() - listenerBus.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 78dc32b4b1525..24ec8d3ab44bf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,8 +49,8 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker - ) extends Logging { + mapOutputTracker: MapOutputTracker) + extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 674322e3034c8..5c9ea88d6b1a4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -82,7 +82,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** * Check if block manager master has a block. Note that this can be used to check for only - * those blocks that are expected to be reported to block manager master. + * those blocks that are reported to block manager master. 
*/ def contains(blockId: BlockId) = { !getLocations(blockId).isEmpty diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index f83c26dafe2e9..3271d4f1375ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -167,7 +167,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { - // TODO(aor): Consolidate usages of + // TODO: Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) blockManagerInfo.values .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } @@ -350,7 +350,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } - } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1d3e94c4b6533..9a29c39a28ab1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -22,11 +22,9 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef private[storage] object BlockManagerMessages { - ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerSlave // Remove a block from the slaves that have it. This can only be used to remove @@ -50,7 +48,6 @@ private[storage] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerMaster case class RegisterBlockManager( @@ -122,5 +119,4 @@ private[storage] object BlockManagerMessages { // For testing. Have the master ask all slaves for the given block's storage level. case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster - } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a57e6f710305a..fcad84669c79a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,7 +90,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) - /** Check if disk block manager has a block */ + /** Check if disk block manager has a block. 
*/ def containsBlock(blockId: BlockId): Boolean = { getBlockLocation(blockId).file.exists() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index cf83a60ffb9e8..06233153c56d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,13 +169,13 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } - /** Remove all the blocks / files and metadata related to a particular shuffle */ + /** Remove all the blocks / files and metadata related to a particular shuffle. */ def removeShuffle(shuffleId: ShuffleId) { removeShuffleBlocks(shuffleId) shuffleStates.remove(shuffleId) } - /** Remove all the blocks / files related to a particular shuffle */ + /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId) { shuffleStates.get(shuffleId) match { case Some(state) => diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 7b75215846a9a..a107c5182b3be 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) + manager.put(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2ef853710a554..7ebed5105b9fd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -78,15 +78,16 @@ private[spark] object MetadataCleaner { conf.getInt("spark.cleaner.ttl", -1) } - def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = - { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) - .toInt + def getDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt } - def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) - { + def setDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 09a6faf33ec60..9f3247a27ba38 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -17,14 +17,14 @@ package org.apache.spark.util -import scala.collection.{JavaConversions, immutable} - -import java.util import 
java.lang.ref.WeakReference +import java.util import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConversions import org.apache.spark.Logging -import java.util.concurrent.atomic.AtomicInteger private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b83033c35f6b7..6b2571cd9295e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -96,7 +96,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { assert(tracker.getServerStatuses(10, 0).isEmpty) } - test("master register shuffle and unregister mapoutput and fetch") { + test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 04e64ee7a45b3..1f5bcca64fc39 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index 0b9847174ac84..f6e6a4c77c820 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util import java.lang.ref.WeakReference +import java.util import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Random From 34f436f7d1799a6fd22b745d339734f220108dae Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 13:17:20 -0700 Subject: [PATCH 21/37] Generalize BroadcastBlockId to remove BroadcastHelperBlockId Rather than having a special purpose BroadcastHelperBlockId just for TorrentBroadcast, we now have a single BroadcastBlockId that has a possibly empty field. This simplifies broadcast clean-up because now we only have to look for one type of block. This commit also simplifies BlockId JSON de/serialization in general by parsing the name through regex with apply. 
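
As a rough illustration (a sketch based on the BlockId definitions added in this
patch, not code that is part of it), the single id type now names both the
broadcast value and the torrent helper blocks through an optional field suffix,
and a name round-trips through the regex-based apply:

    // Assumes the BroadcastBlockId / BlockId.apply definitions from the diff below.
    val dataId  = BroadcastBlockId(5L)            // name: "broadcast_5"
    val metaId  = BroadcastBlockId(5L, "meta")    // name: "broadcast_5_meta"
    val pieceId = BroadcastBlockId(5L, "piece0")  // name: "broadcast_5_piece0"

    // One isBroadcast check now covers every broadcast-related block,
    // and parsing a block name recovers the same kind of id.
    assert(Seq(dataId, metaId, pieceId).forall(_.isBroadcast))
    assert(BlockId("broadcast_5_piece0").name == pieceId.name)

This keeps clean-up to a single pattern: any block whose id is a
BroadcastBlockId with the matching broadcastId belongs to the broadcast being
removed, regardless of its field.
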
--- .../spark/broadcast/TorrentBroadcast.scala | 10 +-- .../org/apache/spark/storage/BlockId.scala | 29 ++++--- .../apache/spark/storage/BlockManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 77 +------------------ .../org/apache/spark/BroadcastSuite.scala | 61 +++++++-------- .../apache/spark/util/JsonProtocolSuite.scala | 14 ---- 6 files changed, 54 insertions(+), 140 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index ab280fad4e28f..dbe65d88104fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -23,7 +23,7 @@ import scala.math import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) @@ -54,7 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -63,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = BroadcastBlockId(id, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -131,7 +131,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo def receiveBroadcast(): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -156,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Receive actual blocks val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = BroadcastBlockId(id, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 301d784b350a3..27e271368ed06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] override def toString = name override def hashCode = name.hashCode @@ -48,18 +48,15 @@ private[spark] case 
class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI def name = "rdd_" + rddId + "_" + splitIndex } -private[spark] -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +// Leave field as an instance variable to avoid matching on it private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - def name = "broadcast_" + broadcastId -} - -private[spark] -case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { - def name = broadcastId.name + "_" + hType + var field = "" + def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { @@ -80,11 +77,19 @@ private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id } +private[spark] object BroadcastBlockId { + def apply(broadcastId: Long, field: String) = { + val blockId = new BroadcastBlockId(broadcastId) + blockId.field = field + blockId + } +} + private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -97,8 +102,8 @@ private[spark] object BlockId { ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId) => BroadcastBlockId(broadcastId.toLong) - case BROADCAST_HELPER(broadcastId, hType) => - BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case BROADCAST_FIELD(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 24ec8d3ab44bf..a88eb1315a37b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -827,9 +827,8 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) - val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { + val blocksToRemove = blockInfo.keys.collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid - case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 346f2b7856791..d9a6af61872d1 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -195,7 +195,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> blockIdToJson(id)) ~ + ("Block ID" -> 
id.toString) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -284,35 +284,6 @@ private[spark] object JsonProtocol { ("Replication" -> storageLevel.replication) } - def blockIdToJson(blockId: BlockId): JValue = { - val blockType = Utils.getFormattedClassName(blockId) - val json: JObject = blockId match { - case rddBlockId: RDDBlockId => - ("RDD ID" -> rddBlockId.rddId) ~ - ("Split Index" -> rddBlockId.splitIndex) - case shuffleBlockId: ShuffleBlockId => - ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ - ("Map ID" -> shuffleBlockId.mapId) ~ - ("Reduce ID" -> shuffleBlockId.reduceId) - case broadcastBlockId: BroadcastBlockId => - "Broadcast ID" -> broadcastBlockId.broadcastId - case broadcastHelperBlockId: BroadcastHelperBlockId => - ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~ - ("Helper Type" -> broadcastHelperBlockId.hType) - case taskResultBlockId: TaskResultBlockId => - "Task ID" -> taskResultBlockId.taskId - case streamBlockId: StreamBlockId => - ("Stream ID" -> streamBlockId.streamId) ~ - ("Unique ID" -> streamBlockId.uniqueId) - case tempBlockId: TempBlockId => - val uuid = UUIDToJson(tempBlockId.id) - "Temp ID" -> uuid - case testBlockId: TestBlockId => - "Test ID" -> testBlockId.id - } - ("Type" -> blockType) ~ json - } - def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ @@ -513,7 +484,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = blockIdFromJson(block \ "Block ID") + val id = BlockId((block \ "Block ID").extract[String]) val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -616,50 +587,6 @@ private[spark] object JsonProtocol { StorageLevel(useDisk, useMemory, deserialized, replication) } - def blockIdFromJson(json: JValue): BlockId = { - val rddBlockId = Utils.getFormattedClassName(RDDBlockId) - val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) - val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) - val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId) - val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) - val streamBlockId = Utils.getFormattedClassName(StreamBlockId) - val tempBlockId = Utils.getFormattedClassName(TempBlockId) - val testBlockId = Utils.getFormattedClassName(TestBlockId) - - (json \ "Type").extract[String] match { - case `rddBlockId` => - val rddId = (json \ "RDD ID").extract[Int] - val splitIndex = (json \ "Split Index").extract[Int] - new RDDBlockId(rddId, splitIndex) - case `shuffleBlockId` => - val shuffleId = (json \ "Shuffle ID").extract[Int] - val mapId = (json \ "Map ID").extract[Int] - val reduceId = (json \ "Reduce ID").extract[Int] - new ShuffleBlockId(shuffleId, mapId, reduceId) - case `broadcastBlockId` => - val broadcastId = (json \ "Broadcast ID").extract[Long] - new BroadcastBlockId(broadcastId) - case `broadcastHelperBlockId` => - val broadcastBlockId = - blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId] - val hType = (json \ "Helper Type").extract[String] - new BroadcastHelperBlockId(broadcastBlockId, hType) - case `taskResultBlockId` => - val taskId = (json \ "Task ID").extract[Long] - new TaskResultBlockId(taskId) - case `streamBlockId` => - val streamId = 
(json \ "Stream ID").extract[Int] - val uniqueId = (json \ "Unique ID").extract[Long] - new StreamBlockId(streamId, uniqueId) - case `tempBlockId` => - val tempId = UUIDFromJson(json \ "Temp ID") - new TempBlockId(tempId) - case `testBlockId` => - val testId = (json \ "Test ID").extract[String] - new TestBlockId(testId) - } - } - def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index a462654197ea0..9e600f1e91aa2 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.FunSuite import org.apache.spark.storage._ import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} +import org.apache.spark.storage.BroadcastBlockId class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { * are present only on the expected nodes. */ private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { - def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] - val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === 1) levels.head match { case (bm, level) => assert(bm.executorId === "") assert(level === StorageLevel.MEMORY_AND_DISK) } - assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists) } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === numSlaves + 1) @@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
- def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] - val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === (if (removeFromDriver) 0 else 1)) - assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists) } testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, @@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { def getBlockIds(id: Long) = { val broadcastBlockId = BroadcastBlockId(id) - val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + val metaBlockId = BroadcastBlockId(id, "meta") // Assume broadcast value is small enough to fit into 1 piece - val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") - Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + val pieceBlockId = BroadcastBlockId(id, "piece0") + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) } // Verify that blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) assert(levels.size === 1) @@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) - blockId match { - case BroadcastHelperBlockId(_, "meta") => - // Meta data is only on the driver - assert(levels.size === 1) - levels.head match { case (bm, _) => assert(bm.executorId === "") } - case _ => - // Other blocks are on both the executors and the driver - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) - } + if (blockId.field == "meta") { + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + } else { + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } } } } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. 
- def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { val expectedNumBlocks = if (removeFromDriver) 0 else 1 var waitTimeMs = 1000L blockIds.foreach { blockId => @@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistBroadcast( numSlaves: Int, broadcastConf: SparkConf, - getBlockIds: Long => Seq[BlockId], - afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, - afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, - afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + getBlockIds: Long => Seq[BroadcastBlockId], + afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, removeFromDriver: Boolean) { sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 67c0a434c9b52..580ac34f5f0b4 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -104,15 +104,6 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(TaskKilled) testTaskEndReason(ExecutorLostFailure) testTaskEndReason(UnknownReason) - - // BlockId - testBlockId(RDDBlockId(1, 2)) - testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L)) - testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) - testBlockId(TaskResultBlockId(1L)) - testBlockId(StreamBlockId(1, 2L)) - testBlockId(TempBlockId(UUID.randomUUID())) } @@ -167,11 +158,6 @@ class JsonProtocolSuite extends FunSuite { assertEquals(reason, newReason) } - private def testBlockId(blockId: BlockId) { - val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) - blockId == newBlockId - } - /** -------------------------------- * | Util methods for comparing events | From fbfeec80cfb7a1bd86847fa22f641d9b9ad7480f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 18:33:11 -0700 Subject: [PATCH 22/37] Add functionality to query executors for their local BlockStatuses Not all blocks are reported to the master. In HttpBroadcast and TorrentBroadcast, for instance, most blocks are not reported to master. The lack of a mechanism to get local block statuses on each executor makes it difficult to test the correctness of un/persisting a broadcast. This new functionality, though only used for testing at the moment, is general enough to be used for other things in the future. 
--- .../spark/network/ConnectionManager.scala | 1 - .../org/apache/spark/storage/BlockInfo.scala | 2 + .../apache/spark/storage/BlockManager.scala | 15 ++-- .../spark/storage/BlockManagerMaster.scala | 33 +++++---- .../storage/BlockManagerMasterActor.scala | 47 ++++++++----- .../spark/storage/BlockManagerMessages.scala | 11 ++- .../storage/BlockManagerSlaveActor.scala | 4 +- .../org/apache/spark/BroadcastSuite.scala | 69 +++++++++++-------- .../apache/spark/ContextCleanerSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 40 +++++++++++ 10 files changed, 150 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index a75130cba2a2e..bb3abf1d032d1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.network -import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index c8f397609a0b4..ef924123a3b11 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -79,3 +79,5 @@ private object BlockInfo { private val BLOCK_PENDING: Long = -1L private val BLOCK_FAILED: Long = -2L } + +private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a88eb1315a37b..dd2dbd1c8a397 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,10 +209,14 @@ private[spark] class BlockManager( } } - /** - * Get storage level of local block. If no info exists for the block, return None. - */ - def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) + /** Return the status of the block identified by the given ID, if it exists. */ + def getStatus(blockId: BlockId): Option[BlockStatus] = { + blockInfo.get(blockId).map { info => + val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + BlockStatus(info.level, memSize, diskSize) + } + } /** * Tell the master about the current storage status of a block. 
This will send a block update @@ -631,10 +635,9 @@ private[spark] class BlockManager( diskStore.putValues(blockId, iterator, level, askForBytes) case ArrayBufferValues(array) => diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => { + case ByteBufferValues(bytes) => bytes.rewind() diskStore.putBytes(blockId, bytes, level) - } } size = res.size res.data match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 5c9ea88d6b1a4..f61aa1d6bc0fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -148,21 +148,30 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Mainly for testing. Ask the driver to query all executors for their storage levels - * regarding this block. This provides an avenue for the driver to learn the storage - * levels of blocks it has not been informed of. + * Return the block's local status on all block managers, if any. * - * WARNING: This could lead to deadlocks if there are any outstanding messages the - * executors are already expecting from the driver. In this case, while the driver is - * waiting for the executors to respond to its GetStorageLevel query, the executors - * are also waiting for a response from the driver to a prior message. + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. * - * The interim solution is to wait for a brief window of time to pass before asking. - * This should suffice, since this mechanism is largely introduced for testing only. + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. 
*/ - def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { - Thread.sleep(waitTimeMs) - askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + def getBlockStatus( + blockId: BlockId, + askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askSlaves) + val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val (blockManagerIds, futures) = response.unzip + val result = Await.result(Future.sequence(futures), timeout) + if (result == null) { + throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) + } + val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] + blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => + status.map { s => (blockManagerId, s) } + }.toMap } /** Stop the driver actor, called only on the Spark driver node */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3271d4f1375ef..2d9445425b879 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -93,6 +93,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetStorageStatus => sender ! storageStatus + case GetBlockStatus(blockId, askSlaves) => + sender ! blockStatus(blockId, askSlaves) + case RemoveRdd(rddId) => sender ! removeRdd(rddId) @@ -126,9 +129,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) - case AskForStorageLevels(blockId) => - sender ! askForStorageLevels(blockId) - case other => logWarning("Got unknown message: " + other) } @@ -254,16 +254,30 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } - // For testing. Ask all block managers for the given block's local storage level, if any. - private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { - val getStorageLevel = GetStorageLevel(blockId) - blockManagerInfo.values.flatMap { info => - val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) - val result = Await.result(future, akkaTimeout) - if (result != null) { - // If the block does not exist on the slave, the slave replies None - result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } - } else None + /** + * Return the block's local status for all block managers, if any. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + * + * Rather than blocking on the block status query, master actor should simply return a + * Future to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. 
+ */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + import context.dispatcher + val getBlockStatus = GetBlockStatus(blockId) + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) }.toMap } @@ -352,9 +366,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } } - -private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, @@ -371,6 +382,8 @@ private[spark] class BlockManagerInfo( logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) + def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 9a29c39a28ab1..afb2c6a12ce67 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -41,9 +41,6 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - // For testing. Ask the slave for the block's storage level. - case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -113,10 +110,10 @@ private[storage] object BlockManagerMessages { case object GetMemoryStatus extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster - case object GetStorageStatus extends ToBlockManagerMaster - // For testing. Have the master ask all slaves for the given block's storage level. - case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 85b8ec40c0ea3..016ade428c68f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -50,7 +50,7 @@ class BlockManagerSlaveActor( case RemoveBroadcast(broadcastId, removeFromDriver) => blockManager.removeBroadcast(broadcastId, removeFromDriver) - case GetStorageLevel(blockId) => - sender ! blockManager.getLevel(blockId) + case GetBlockStatus(blockId, _) => + sender ! 
blockManager.getStatus(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 9e600f1e91aa2..d28496e316a34 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -107,22 +107,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that the broadcast file is created, and blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === 1) - levels.head match { case (bm, level) => - assert(bm.executorId === "") - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") } - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists) + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") } // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") } } @@ -130,9 +134,13 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === (if (removeFromDriver) 0 else 1)) - assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists) + val statuses = bmm.getBlockStatus(blockIds.head) + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) } testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, @@ -158,11 +166,13 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) - assert(levels.size === 1) - levels.head match { case (bm, level) => - assert(bm.executorId === "") - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") } } } @@ -170,16 +180,18 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + val statuses = bmm.getBlockStatus(blockId) if (blockId.field == "meta") { // Meta data is only on the driver - assert(levels.size === 1) - levels.head match { case (bm, _) => assert(bm.executorId === "") } + assert(statuses.size === 1) + statuses.head match { case (bm, _) => assert(bm.executorId === "") } } else { // Other blocks are on both the executors and the driver - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") } } } @@ -189,12 +201,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. 
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { val expectedNumBlocks = if (removeFromDriver) 0 else 1 - var waitTimeMs = 1000L + val possiblyNot = if (removeFromDriver) "" else " not" blockIds.foreach { blockId => - // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks - val levels = bmm.askForStorageLevels(blockId, waitTimeMs) - assert(levels.size === expectedNumBlocks) - waitTimeMs = 0L + val statuses = bmm.getBlockStatus(blockId) + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 6a12cb6603700..3d95547b20fc1 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -267,7 +267,7 @@ class CleanerTester( "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") // Verify that the broadcast is in the driver's block manager - assert(broadcastIds.forall(bid => blockManager.getLevel(broadcastBlockId(bid)).isDefined), + assert(broadcastIds.forall(bid => blockManager.getStatus(broadcastBlockId(bid)).isDefined), "One ore more broadcasts have not been persisted in the driver's block manager") } @@ -291,7 +291,7 @@ class CleanerTester( // Verify all broadcasts have been unpersisted assert(broadcastIds.forall { bid => - blockManager.master.askForStorageLevels(broadcastBlockId(bid)).isEmpty + blockManager.master.getBlockStatus(broadcastBlockId(bid)).isEmpty }) return diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1f5bcca64fc39..bddbd381c2665 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -745,6 +745,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(!store.get("list5").isDefined, "list5 was in store") } + test("query block statuses") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](200)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getLocations("list1").size === 0) + assert(store.master.getLocations("list2").size === 1) + assert(store.master.getLocations("list3").size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + + // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. 
+ store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // getLocations should return nothing because the master is not informed + // getBlockStatus without asking slaves should have the same result + // getBlockStatus with asking slaves, however, should present the actual block statuses + assert(store.master.getLocations("list4").size === 0) + assert(store.master.getLocations("list5").size === 0) + assert(store.master.getLocations("list6").size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr, mapOutputTracker) From 88904a3659fe4a81bdfb2a6b615894d926af3fe1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 23:02:11 -0700 Subject: [PATCH 23/37] Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap This allows us to get rid of WrappedJavaHashMap without much duplicate code. --- .../scala/org/apache/spark/SparkContext.scala | 1 - .../apache/spark/storage/BlockManager.scala | 7 +- .../spark/util/TimeStampedHashMap.scala | 117 +++++++--- .../util/TimeStampedWeakValueHashMap.scala | 164 +++++++------- .../spark/util/WrappedJavaHashMap.scala | 152 ------------- .../spark/util/WrappedJavaHashMapSuite.scala | 206 ------------------ 6 files changed, 168 insertions(+), 479 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 79574c271cfb6..13fba1e0dfe5d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index dd2dbd1c8a397..991881b00c0eb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,7 +209,7 @@ private[spark] class BlockManager( } } - /** Return the status of the block identified by the given ID, if it exists. 
*/ + /** Get the BlockStatus for the block identified by the given ID, if it exists.*/ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L @@ -635,9 +635,10 @@ private[spark] class BlockManager( diskStore.putValues(blockId, iterator, level, askForBytes) case ArrayBufferValues(array) => diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => + case ByteBufferValues(bytes) => { bytes.rewind() diskStore.putBytes(blockId, bytes, level) + } } size = res.size res.data match { @@ -872,7 +873,7 @@ private[spark] class BlockManager( } private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { - val iterator = blockInfo.internalMap.entrySet().iterator() + val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index c4d770fecdf74..1721818c212f9 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,64 +17,108 @@ package org.apache.spark.util +import java.util.Set +import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap +import scala.collection.{immutable, JavaConversions, mutable} + import org.apache.spark.Logging -private[util] case class TimeStampedValue[T](timestamp: Long, value: T) +private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) /** - * A map that stores the timestamp of when a key was inserted along with the value. If specified, - * the timestamp of each pair can be updated every time it is accessed. - * Key-value pairs whose timestamps are older than a particular - * threshold time can then be removed using the clearOldValues method. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. - * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. 
* - * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be - * updated when it is accessed + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed */ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging { + extends mutable.Map[A, B]() with Logging { - private[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() + private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedHashMap[K1, V1]() + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) + } + Option(value).map(_.value) } - def internalMap = internalJavaMap + def iterator: Iterator[(A, B)] = { + val jIterator = getEntrySet.iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + } - override def get(key: A): Option[B] = { - val timeStampedValue = internalMap.get(key) - if (updateTimeStampOnGet && timeStampedValue != null) { - internalJavaMap.replace(key, timeStampedValue, - TimeStampedValue(currentTime, timeStampedValue.value)) - } - Option(timeStampedValue).map(_.value) + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet() + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] + newMap.internalMap.putAll(oldInternalMap) + kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } + newMap } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedValue[B] = { - new TimeStampedValue(currentTime, v) + + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) } - @inline override protected def internalValueToExternalValue(iv: TimeStampedValue[B]): B = { - iv.value + override def apply(key: A): B = { + val value = internalMap.get(key) + Option(value).map(_.value).getOrElse { throw new NoSuchElementException() } } - /** Atomically put if a key is absent. This exposes the existing API of ConcurrentHashMap. */ + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap) + .map { case (k, TimeStampedValue(v, t)) => (k, v) } + .filter(p) + } + + override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val iterator = getEntrySet.iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue.value) + f(kv) + } + } + + // Should we return previous value directly or as Option? 
def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalJavaMap.putIfAbsent(key, TimeStampedValue(currentTime, value)) + val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) Option(prev).map(_.value) } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ + def toMap: immutable.Map[A, B] = iterator.toMap + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalJavaMap.entrySet().iterator() + val iterator = getEntrySet.iterator() while (iterator.hasNext) { val entry = iterator.next() if (entry.getValue.timestamp < threshTime) { @@ -86,11 +130,12 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` + * Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } - private def currentTime: Long = System.currentTimeMillis() + private def currentTime: Long = System.currentTimeMillis + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 9f3247a27ba38..f814f58261bf3 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -18,113 +18,115 @@ package org.apache.spark.util import java.lang.ref.WeakReference -import java.util -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions - -import org.apache.spark.Logging - -private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { - def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) -} +import scala.collection.{immutable, mutable} /** - * A map that stores the timestamp of when a key was inserted along with the value, - * while ensuring that the values are weakly referenced. If the value is garbage collected and - * the weak reference is null, get() operation returns the key be non-existent. However, - * the key is actually not removed in the current implementation. Key-value pairs whose - * timestamps are older than a particular threshold time can then be removed using the - * clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a - * drop-in replacement for Scala HashMaps. + * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. + * + * If the value is garbage collected and the weak reference is null, get() operation returns + * a non-existent value. However, the corresponding key is actually not removed in the current + * implementation. Key-value pairs whose timestamps are older than a particular threshold time + * can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map + * interface to allow it to be a drop-in replacement for Scala HashMaps. * * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. 
*/ +private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() { + + import TimeStampedWeakValueHashMap._ -private[spark] class TimeStampedWeakValueHashMap[A, B]() - extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) - /** Number of inserts after which keys whose weak ref values are null will be cleaned */ - private val CLEANUP_INTERVAL = 1000 + def get(key: A): Option[B] = internalMap.get(key) - /** Counter for counting the number of inserts */ - private val insertCounts = new AtomicInteger(0) + def iterator: Iterator[(A, B)] = internalMap.iterator + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedWeakValueHashMap[A, B1] + newMap.internalMap += kv + newMap + } - private[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { - new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap -= key + newMap } - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedWeakValueHashMap[K1, V1]() + override def += (kv: (A, B)): this.type = { + internalMap += kv + this } - override def +=(kv: (A, B)): this.type = { - // Cleanup null value at certain intervals - if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) { - cleanNullValues() - } - super.+=(kv) + override def -= (key: A): this.type = { + internalMap -= key + this } - override def get(key: A): Option[B] = { - Option(internalJavaMap.get(key)).flatMap { weakValue => - val value = weakValue.weakValue.get - if (value == null) { - internalJavaMap.remove(key) - } - Option(value) - } + override def update(key: A, value: B) = this += ((key, value)) + + override def apply(key: A): B = internalMap.apply(key) + + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p) + + override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f) + + def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) + + def toMap: immutable.Map[A, B] = iterator.toMap + + /** + * Remove old key-value pairs that have timestamp earlier than `threshTime`. + */ + def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + +} + +/** + * Helper methods for converting to and from WeakReferences. 
+ */ +private[spark] object TimeStampedWeakValueHashMap { + + /* Implicit conversion methods to WeakReferences */ + + implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) + + implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { + kv match { case (k, v) => (k, toWeakReference(v)) } } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = { - new TimeStampedWeakValue(currentTime, v) + implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { + (kv: (K, WeakReference[V])) => p(kv) } - @inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = { - iv.weakValue.get + /* Implicit conversion methods from WeakReferences */ + + implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get + + implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { + v.map(fromWeakReference) } - override def iterator: Iterator[(A, B)] = { - val iterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(iterator).flatMap(kv => { - val (key, value) = (kv.getKey, kv.getValue.weakValue.get) - if (value != null) Seq((key, value)) else Seq.empty - }) + implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { + kv match { case (k, v) => (k, fromWeakReference(v)) } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ - def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.timestamp < threshTime) { - val value = entry.getValue.weakValue.get - if (f != null && value != null) { - f(entry.getKey, value) - } - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } + implicit def fromWeakReferenceIterator[K, V]( + it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { + it.map(fromWeakReferenceTuple) } - /** - * Removes keys whose weak referenced values have become null. - */ - private def cleanNullValues() { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.weakValue.get == null) { - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } + implicit def fromWeakReferenceMap[K, V]( + map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { + mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) } - private def currentTime = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala deleted file mode 100644 index 6cc3007f5d7ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.Map -import java.util.{Map => JMap} -import java.util.Map.{Entry => JMapEntry} -import scala.collection.{immutable, JavaConversions} -import scala.reflect.ClassTag - -/** - * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the - * exposed key-value type is different from the internal type. This allows these - * implementations of WrappedJavaHashMap to be drop-in replacements for Scala HashMaps. - * - * While Java <-> Scala conversion methods exists, its hard to understand the performance - * implications and thread safety of the Scala wrapper. This class allows you to convert - * between types and applying the necessary overridden methods to take care of performance. - * - * Note that the threading behavior of an implementation of WrappedJavaHashMap is tied to that of - * the internal Java HashMap used in the implementation. Each implementation must use - * necessary traits (e.g, scala.collection.mutable.SynchronizedMap), etc. to achieve the - * desired thread safety. - * - * @tparam K External key type - * @tparam V External value type - * @tparam IK Internal key type - * @tparam IV Internal value type - */ -private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] { - - /* Methods that must be defined. */ - - /** - * Internal Java HashMap that is being wrapped. - * Scoped private[util] so that rest of Spark code cannot - * directly access the internal map. - */ - private[util] val internalJavaMap: JMap[IK, IV] - - /** Method to get a new instance of the internal Java HashMap. */ - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] - - /* - Methods that convert between internal and external types. These implementations - optimistically assume that the internal types are same as external types. These must - be overridden if the internal and external types are different. Otherwise there will be - runtime exceptions. - */ - - @inline protected def externalKeyToInternalKey(k: K): IK = { - k.asInstanceOf[IK] // works only if K is same or subclass of K - } - - @inline protected def externalValueToInternalValue(v: V): IV = { - v.asInstanceOf[IV] // works only if V is same or subclass of - } - - @inline protected def internalKeyToExternalKey(ik: IK): K = { - ik.asInstanceOf[K] - } - - @inline protected def internalValueToExternalValue(iv: IV): V = { - iv.asInstanceOf[V] - } - - @inline protected def internalPairToExternalPair(ip: JMapEntry[IK, IV]): (K, V) = { - (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue) ) - } - - /* Implicit methods to convert the types. 
*/ - - @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) - - @inline implicit private def convExtValueToIntValue(v: V) = externalValueToInternalValue(v) - - @inline implicit private def convIntKeyToExtKey(ia: IK) = internalKeyToExternalKey(ia) - - @inline implicit private def convIntValueToExtValue(ib: IV) = internalValueToExternalValue(ib) - - @inline implicit private def convIntPairToExtPair(ip: JMapEntry[IK, IV]) = { - internalPairToExternalPair(ip) - } - - /* Methods that must be implemented for a scala.collection.mutable.Map */ - - def get(key: K): Option[V] = { - Option(internalJavaMap.get(key)) - } - - def iterator: Iterator[(K, V)] = { - val jIterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) - } - - /* Other methods that are implemented to ensure performance. */ - - def +=(kv: (K, V)): this.type = { - internalJavaMap.put(kv._1, kv._2) - this - } - - def -=(key: K): this.type = { - internalJavaMap.remove(key) - this - } - - override def + [V1 >: V](kv: (K, V1)): Map[K, V1] = { - val newMap = newInstance[K, V1]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap += kv - newMap - } - - override def - (key: K): Map[K, V] = { - val newMap = newInstance[K, V]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap -= key - } - - override def foreach[U](f: ((K, V)) => U) { - val jIterator = internalJavaMap.entrySet().iterator() - while(jIterator.hasNext) { - f(jIterator.next()) - } - } - - override def empty: Map[K, V] = newInstance[K, V]() - - override def size: Int = internalJavaMap.size - - override def filter(p: ((K, V)) => Boolean): Map[K, V] = { - newInstance[K, V]() ++= iterator.filter(p) - } - - def toMap: immutable.Map[K, V] = iterator.toMap -} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala deleted file mode 100644 index f6e6a4c77c820..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import java.lang.ref.WeakReference -import java.util - -import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Random - -import org.scalatest.FunSuite - -class WrappedJavaHashMapSuite extends FunSuite { - - // Test the testMap function - a Scala HashMap should obviously pass - testMap(new HashMap[String, String]()) - - // Test a simple WrappedJavaHashMap - testMap(new TestMap[String, String]()) - - // Test TimeStampedHashMap - testMap(new TimeStampedHashMap[String, String]) - - testMapThreadSafety(new TimeStampedHashMap[String, String]) - - test("TimeStampedHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedHashMap[String, String](false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedHashMap[String, String](true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis() - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.internalMap.get("k1").timestamp < threshTime1) - assert(map1.internalMap.get("k2").timestamp >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - // Test TimeStampedHashMap - testMap(new TimeStampedWeakValueHashMap[String, String]) - - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]) - - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String]() - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalJavaMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - } - - - test("TimeStampedWeakValueHashMap - get not returning null when weak reference is cleared") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - - map("k1") = strongRef - assert(map("k1") === strongRef) - - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. 
- while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.internalJavaMap.get("k1").weakValue.get == null) - assert(map.get("k1") === None) - - // TODO (TD): Test clearing of null-value pairs - } - - def testMap(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - - test(name + " - basic test") { - val testMap1 = newMap() - - // put and get - testMap1 += (("k1", "v1")) - assert(testMap1.get("k1").get === "v1") - testMap1("k2") = "v2" - assert(testMap1.get("k2").get === "v2") - assert(testMap1("k2") === "v2") - - // remove - testMap1.remove("k1") - assert(testMap1.get("k1").isEmpty) - testMap1.remove("k2") - intercept[Exception] { - testMap1("k2") // Map.apply() causes exception - } - - // multi put - val keys = (1 to 100).map(_.toString) - val pairs = keys.map(x => (x, x * 2)) - val testMap2 = newMap() - assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) - testMap2 ++= pairs - - // iterator - assert(testMap2.iterator.toSet === pairs.toSet) - testMap2("k1") = "v1" - - // foreach - val buffer = new ArrayBuffer[(String, String)] - testMap2.foreach(x => buffer += x) - assert(testMap2.toSet === buffer.toSet) - - // multi remove - testMap2 --= keys - assert(testMap2.size === 1) - assert(testMap2.iterator.toSeq.head === ("k1", "v1")) - } - } - - def testMapThreadSafety(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - val testMap = newMap() - @volatile var error = false - - def getRandomKey(m: Map[String, String]): Option[String] = { - val keys = testMap.keysIterator.toSeq - if (keys.nonEmpty) { - Some(keys(Random.nextInt(keys.size))) - } else { - None - } - } - - val threads = (1 to 100).map(i => new Thread() { - override def run() { - try { - for (j <- 1 to 1000) { - Random.nextInt(3) match { - case 0 => - testMap(Random.nextString(10)) = Random.nextDouble.toString // put - case 1 => - getRandomKey(testMap).map(testMap.get) // get - case 2 => - getRandomKey(testMap).map(testMap.remove) // remove - } - } - } catch { - case t : Throwable => - error = true - throw t - } - } - }) - - test(name + " - threading safety test") { - threads.map(_.start) - threads.map(_.join) - assert(!error) - } - } -} - -class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { - private[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TestMap[K1, V1] - } -} From 7ed72fbbef4be653bce83ce75ad9929d29b36fcf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 31 Mar 2014 14:24:18 -0700 Subject: [PATCH 24/37] Fix style test fail + remove verbose test message regarding broadcast --- .../org/apache/spark/broadcast/HttpBroadcast.scala | 12 ++++++------ .../org/apache/spark/storage/BlockManager.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index d8981bb42e684..79216bd2b8404 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -219,12 +219,12 @@ private[spark] object HttpBroadcast extends Logging { private def deleteBroadcastFile(file: File) { try { - if (!file.exists) { - logWarning("Broadcast 
file to be deleted does not exist: %s".format(file)) - } else if (file.delete()) { - logInfo("Deleted broadcast file: %s".format(file)) - } else { - logWarning("Could not delete broadcast file: %s".format(file)) + if (file.exists) { + if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) + } } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 991881b00c0eb..c90abb187bdb1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,7 +209,7 @@ private[spark] class BlockManager( } } - /** Get the BlockStatus for the block identified by the given ID, if it exists.*/ + /** Get the BlockStatus for the block identified by the given ID, if it exists. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L From 5016375fb32c0de8df0529467a3c5a57fe73a18f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Apr 2014 13:16:11 -0700 Subject: [PATCH 25/37] Address TD's comments --- .../org/apache/spark/ContextCleaner.scala | 7 +- .../apache/spark/broadcast/Broadcast.scala | 31 +++++--- .../spark/broadcast/BroadcastManager.scala | 3 +- .../spark/broadcast/HttpBroadcast.scala | 25 ++++--- .../spark/broadcast/TorrentBroadcast.scala | 35 ++++++---- .../main/scala/org/apache/spark/rdd/RDD.scala | 1 - .../org/apache/spark/storage/BlockId.scala | 21 ++---- .../apache/spark/storage/BlockManager.scala | 4 +- .../spark/storage/BlockManagerMaster.scala | 23 +++--- .../storage/BlockManagerMasterActor.scala | 11 +-- .../org/apache/spark/util/JsonProtocol.scala | 70 ++++++++++++++++++- .../org/apache/spark/BroadcastSuite.scala | 17 +++-- .../spark/storage/BlockManagerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 15 +++- 14 files changed, 181 insertions(+), 84 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index f856a13f84dec..b71b7fa517fd2 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -169,18 +169,17 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { // Used for testing - private[spark] def cleanupRDD(rdd: RDD[_]) { + def cleanupRDD(rdd: RDD[_]) { doCleanupRDD(rdd.id) } - private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { doCleanupShuffle(shuffleDependency.shuffleId) } - private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) { + def cleanupBroadcast[T](broadcast: Broadcast[T]) { doCleanupBroadcast(broadcast.id) } - } private object ContextCleaner { diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 3a2fef05861e6..81e0e5297683b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -19,6 +19,8 @@ package org.apache.spark.broadcast import java.io.Serializable +import org.apache.spark.SparkException + /** * A broadcast variable. 
Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. They can be used, for @@ -49,25 +51,36 @@ import java.io.Serializable */ abstract class Broadcast[T](val id: Long) extends Serializable { + protected var _isValid: Boolean = true + /** * Whether this Broadcast is actually usable. This should be false once persisted state is * removed from the driver. */ - protected var isValid: Boolean = true + def isValid: Boolean = _isValid def value: T /** - * Remove all persisted state associated with this broadcast. Overriding implementations - * should set isValid to false if persisted state is also removed from the driver. - * - * @param removeFromDriver Whether to remove state from the driver. - * If true, the resulting broadcast should no longer be valid. + * Remove all persisted state associated with this broadcast on the executors. The next use + * of this broadcast on the executors will trigger a remote fetch. */ - def unpersist(removeFromDriver: Boolean) + def unpersist() - // We cannot define abstract readObject and writeObject here due to some weird issues - // with these methods having to be 'private' in sub-classes. + /** + * Remove all persisted state associated with this broadcast on both the executors and the + * driver. Overriding implementations should set isValid to false. + */ + private[spark] def destroy() + + /** + * If this broadcast is no longer valid, throw an exception. + */ + protected def assertValid() { + if (!_isValid) { + throw new SparkException("Attempted to use %s when is no longer valid!".format(toString)) + } + } override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 85d62aae03959..c3ea16ff9eb5e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -25,7 +25,7 @@ private[spark] class BroadcastManager( val isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { + extends Logging { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -63,5 +63,4 @@ private[spark] class BroadcastManager( def unbroadcast(id: Long, removeFromDriver: Boolean) { broadcastFactory.unbroadcast(id, removeFromDriver) } - } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 79216bd2b8404..ec5acf5f23f5f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -31,7 +31,10 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - override def value = value_ + def value: T = { + assertValid() + value_ + } val blockId = BroadcastBlockId(id) @@ -45,17 +48,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } /** - * Remove all persisted state associated with this HTTP broadcast. - * @param removeFromDriver Whether to remove state from the driver. + * Remove all persisted state associated with this HTTP broadcast on the executors. 
+ */ + def unpersist() { + HttpBroadcast.unpersist(id, removeFromDriver = false) + } + + /** + * Remove all persisted state associated with this HTTP Broadcast on both the executors + * and the driver. */ - override def unpersist(removeFromDriver: Boolean) { - isValid = !removeFromDriver - HttpBroadcast.unpersist(id, removeFromDriver) + private[spark] def destroy() { + _isValid = false + HttpBroadcast.unpersist(id, removeFromDriver = true) } // Used by the JVM when serializing this object private def writeObject(out: ObjectOutputStream) { - assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + assertValid() out.defaultWriteObject() } @@ -231,5 +241,4 @@ private[spark] object HttpBroadcast extends Logging { logError("Exception while deleting broadcast file: %s".format(file), e) } } - } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index dbe65d88104fb..590caa9699dd3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,10 @@ import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - override def value = value_ + def value = { + assertValid() + value_ + } val broadcastId = BroadcastBlockId(id) @@ -47,7 +50,23 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo sendBroadcast() } - def sendBroadcast() { + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + def unpersist() { + TorrentBroadcast.unpersist(id, removeFromDriver = false) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on both the executors + * and the driver. + */ + private[spark] def destroy() { + _isValid = false + TorrentBroadcast.unpersist(id, removeFromDriver = true) + } + + private def sendBroadcast() { val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes @@ -71,18 +90,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - /** - * Remove all persisted state associated with this Torrent broadcast. - * @param removeFromDriver Whether to remove state from the driver. 
- */ - override def unpersist(removeFromDriver: Boolean) { - isValid = !removeFromDriver - TorrentBroadcast.unpersist(id, removeFromDriver) - } - // Used by the JVM when serializing this object private def writeObject(out: ObjectOutputStream) { - assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + assertValid() out.defaultWriteObject() } @@ -240,7 +250,6 @@ private[spark] object TorrentBroadcast extends Logging { def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) } - } private[spark] case class TorrentBlock( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e8d36e6bfc810..ea22ad29bc885 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1128,5 +1128,4 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } - } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 27e271368ed06..cffea28fbf794 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,9 +53,7 @@ private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: I def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -// Leave field as an instance variable to avoid matching on it -private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - var field = "" +private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @@ -77,19 +75,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id } -private[spark] object BroadcastBlockId { - def apply(broadcastId: Long, field: String) = { - val blockId = new BroadcastBlockId(broadcastId) - blockId.field = field - blockId - } -} - private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r - val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -100,10 +89,8 @@ private[spark] object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) - case BROADCAST(broadcastId) => - BroadcastBlockId(broadcastId.toLong) - case BROADCAST_FIELD(broadcastId, field) => - BroadcastBlockId(broadcastId.toLong, field) + case BROADCAST(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c90abb187bdb1..925cee1eb6be7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -832,7 +832,7 @@ private[spark] class BlockManager( def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { 
logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.collect { - case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid + case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } @@ -897,7 +897,7 @@ private[spark] class BlockManager( def shouldCompress(blockId: BlockId): Boolean = blockId match { case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_) => compressBroadcast + case BroadcastBlockId(_, _) => compressBroadcast case RDDBlockId(_, _) => compressRdds case TempBlockId(_) => compressShuffleSpill case _ => false diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f61aa1d6bc0fc..4e45bb8452fd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -106,9 +106,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveBlock(blockId)) } - /** - * Remove all blocks belonging to the given RDD. - */ + /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future onFailure { @@ -119,16 +117,12 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } - /** - * Remove all blocks belonging to the given shuffle. - */ + /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int) { askDriverWithReply(RemoveShuffle(shuffleId)) } - /** - * Remove all blocks belonging to the given broadcast. - */ + /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) } @@ -148,20 +142,21 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Return the block's local status on all block managers, if any. + * Return the block's status on all block managers, if any. * * If askSlaves is true, this invokes the master to query each block manager for the most * updated block statuses. This is useful when the master is not informed of the given block * by all block managers. - * - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor - * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. */ def getBlockStatus( blockId: BlockId, askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) + /* + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. 
+ */ val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 2d9445425b879..4159fc733a566 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -255,21 +255,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Return the block's local status for all block managers, if any. + * Return the block's status for all block managers, if any. * * If askSlaves is true, the master queries each block manager for the most updated block * statuses. This is useful when the master is not informed of the given block by all block * managers. - * - * Rather than blocking on the block status query, master actor should simply return a - * Future to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. */ private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master actor should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. + */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d9a6af61872d1..c23b6b3944ba0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -195,7 +195,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ + ("Block ID" -> blockIdToJson(id)) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -284,6 +284,33 @@ private[spark] object JsonProtocol { ("Replication" -> storageLevel.replication) } + def blockIdToJson(blockId: BlockId): JValue = { + val blockType = Utils.getFormattedClassName(blockId) + val json: JObject = blockId match { + case rddBlockId: RDDBlockId => + ("RDD ID" -> rddBlockId.rddId) ~ + ("Split Index" -> rddBlockId.splitIndex) + case shuffleBlockId: ShuffleBlockId => + ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ + ("Map ID" -> shuffleBlockId.mapId) ~ + ("Reduce ID" -> shuffleBlockId.reduceId) + case broadcastBlockId: BroadcastBlockId => + ("Broadcast ID" -> broadcastBlockId.broadcastId) ~ + ("Field" -> broadcastBlockId.field) + case taskResultBlockId: TaskResultBlockId => + "Task ID" -> taskResultBlockId.taskId + case streamBlockId: StreamBlockId => + ("Stream ID" -> streamBlockId.streamId) ~ + ("Unique ID" -> streamBlockId.uniqueId) + case tempBlockId: TempBlockId => + val uuid = UUIDToJson(tempBlockId.id) + "Temp ID" -> uuid + case testBlockId: TestBlockId => + "Test ID" -> testBlockId.id + } + ("Type" -> 
blockType) ~ json + } + def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ @@ -484,7 +511,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = BlockId((block \ "Block ID").extract[String]) + val id = blockIdFromJson(block \ "Block ID") val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -587,6 +614,45 @@ private[spark] object JsonProtocol { StorageLevel(useDisk, useMemory, deserialized, replication) } + def blockIdFromJson(json: JValue): BlockId = { + val rddBlockId = Utils.getFormattedClassName(RDDBlockId) + val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) + val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) + val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) + val streamBlockId = Utils.getFormattedClassName(StreamBlockId) + val tempBlockId = Utils.getFormattedClassName(TempBlockId) + val testBlockId = Utils.getFormattedClassName(TestBlockId) + + (json \ "Type").extract[String] match { + case `rddBlockId` => + val rddId = (json \ "RDD ID").extract[Int] + val splitIndex = (json \ "Split Index").extract[Int] + new RDDBlockId(rddId, splitIndex) + case `shuffleBlockId` => + val shuffleId = (json \ "Shuffle ID").extract[Int] + val mapId = (json \ "Map ID").extract[Int] + val reduceId = (json \ "Reduce ID").extract[Int] + new ShuffleBlockId(shuffleId, mapId, reduceId) + case `broadcastBlockId` => + val broadcastId = (json \ "Broadcast ID").extract[Long] + val field = (json \ "Field").extract[String] + new BroadcastBlockId(broadcastId, field) + case `taskResultBlockId` => + val taskId = (json \ "Task ID").extract[Long] + new TaskResultBlockId(taskId) + case `streamBlockId` => + val streamId = (json \ "Stream ID").extract[Int] + val uniqueId = (json \ "Unique ID").extract[Long] + new StreamBlockId(streamId, uniqueId) + case `tempBlockId` => + val tempId = UUIDFromJson(json \ "Temp ID") + new TempBlockId(tempId) + case `testBlockId` => + val testId = (json \ "Test ID").extract[String] + new TestBlockId(testId) + } + } + def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index d28496e316a34..f1bfb6666ddda 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -246,12 +246,20 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { afterUsingBroadcast(blocks, blockManagerMaster) // Unpersist broadcast - broadcast.unpersist(removeFromDriver) + if (removeFromDriver) { + broadcast.destroy() + } else { + broadcast.unpersist() + } afterUnpersist(blocks, blockManagerMaster) - if (!removeFromDriver) { - // The broadcast variable is not completely destroyed (i.e. state still exists on driver) - // Using the variable again should yield the same answer as before. + // If the broadcast is removed from driver, all subsequent uses of the broadcast variable + // should throw SparkExceptions. Otherwise, the result should be the same as before. 
+ if (removeFromDriver) { + // Using this variable on the executors crashes them, which hangs the test. + // Instead, crash the driver by directly accessing the broadcast value. + intercept[SparkException] { broadcast.value } + } else { val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } @@ -263,5 +271,4 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) conf } - } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bddbd381c2665..b47de5eab95a4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -773,7 +773,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result - // getBlockStatus with asking slaves, however, should present the actual block statuses + // getBlockStatus with asking slaves, however, should return the actual block statuses assert(store.master.getLocations("list4").size === 0) assert(store.master.getLocations("list5").size === 0) assert(store.master.getLocations("list6").size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 580ac34f5f0b4..6bc8bcc036cb3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -104,6 +104,14 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(TaskKilled) testTaskEndReason(ExecutorLostFailure) testTaskEndReason(UnknownReason) + + // BlockId + testBlockId(RDDBlockId(1, 2)) + testBlockId(ShuffleBlockId(1, 2, 3)) + testBlockId(BroadcastBlockId(1L, "")) + testBlockId(TaskResultBlockId(1L)) + testBlockId(StreamBlockId(1, 2L)) + testBlockId(TempBlockId(UUID.randomUUID())) } @@ -158,6 +166,11 @@ class JsonProtocolSuite extends FunSuite { assertEquals(reason, newReason) } + private def testBlockId(blockId: BlockId) { + val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) + blockId == newBlockId + } + /** -------------------------------- * | Util methods for comparing events | @@ -542,4 +555,4 @@ class JsonProtocolSuite extends FunSuite { {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} """ - } +} From f0aabb1c8496dc79daeb6d090fb36ceef310622b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Apr 2014 19:55:45 -0700 Subject: [PATCH 26/37] Correct semantics for TimeStampedWeakValueHashMap + add tests This largely accounts for the cases when WeakReference becomes no longer strongly reachable, in which case the map should return None for all get() operations, and should skip the entry for all listing operations. 
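To make the intended semantics concrete, the behavior can be sketched with a small stand-alone map (the name WeakValueMap and the use of TrieMap are illustrative only; the patch's TimeStampedWeakValueHashMap wraps a ConcurrentHashMap-backed TimeStampedHashMap and additionally tracks timestamps and periodically purges cleared entries):

    import java.lang.ref.WeakReference
    import scala.collection.concurrent.TrieMap

    // Values are held only weakly. Once a referent is no longer strongly
    // reachable and has been collected, get() returns None and listing
    // operations skip the entry, even though the key may still be present.
    class WeakValueMap[A, B <: AnyRef] {
      private val underlying = TrieMap.empty[A, WeakReference[B]]

      def update(key: A, value: B) {
        underlying.put(key, new WeakReference(value))
      }

      // A cleared reference is treated as a missing entry.
      def get(key: A): Option[B] =
        underlying.get(key).flatMap(ref => Option(ref.get))

      // Iteration only surfaces entries whose referents are still alive.
      def iterator: Iterator[(A, B)] =
        underlying.iterator.flatMap { case (k, ref) => Option(ref.get).map(v => (k, v)) }
    }

Note that System.gc() is only a hint to the JVM, which is why the tests added below retry collection in a bounded loop before asserting that a weakly referenced value has disappeared.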
--- .../apache/spark/broadcast/Broadcast.scala | 2 +- .../spark/util/TimeStampedHashMap.scala | 43 +-- .../util/TimeStampedWeakValueHashMap.scala | 78 ++++-- .../spark/util/TimeStampedHashMapSuite.scala | 264 ++++++++++++++++++ 4 files changed, 350 insertions(+), 37 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 81e0e5297683b..b28e15a6840d9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -78,7 +78,7 @@ abstract class Broadcast[T](val id: Long) extends Serializable { */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s when is no longer valid!".format(toString)) + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 1721818c212f9..5c239329588d8 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,7 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{immutable, JavaConversions, mutable} +import scala.collection.{JavaConversions, mutable} import org.apache.spark.Logging @@ -50,11 +50,11 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator() + val jIterator = getEntrySet.iterator JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet() + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedHashMap[A, B1] @@ -86,8 +86,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def apply(key: A): B = { - val value = internalMap.get(key) - Option(value).map(_.value).getOrElse { throw new NoSuchElementException() } + get(key).getOrElse { throw new NoSuchElementException() } } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { @@ -101,9 +100,9 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa override def size: Int = internalMap.size override def foreach[U](f: ((A, B)) => U) { - val iterator = getEntrySet.iterator() - while(iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() val kv = (entry.getKey, entry.getValue.value) f(kv) } @@ -115,27 +114,39 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa Option(prev).map(_.value) } - def toMap: immutable.Map[A, B] = iterator.toMap + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = getEntrySet.iterator() - while (iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() if (entry.getValue.timestamp < 
threshTime) { f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } private def currentTime: Long = System.currentTimeMillis + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f814f58261bf3..b65017d6806c6 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -18,47 +18,61 @@ package org.apache.spark.util import java.lang.ref.WeakReference +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.{immutable, mutable} +import scala.collection.mutable + +import org.apache.spark.Logging /** * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. * - * If the value is garbage collected and the weak reference is null, get() operation returns - * a non-existent value. However, the corresponding key is actually not removed in the current - * implementation. Key-value pairs whose timestamps are older than a particular threshold time - * can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map - * interface to allow it to be a drop-in replacement for Scala HashMaps. + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. * * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. */ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() { + extends mutable.Map[A, B]() with Logging { import TimeStampedWeakValueHashMap._ private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) + + /** Return a map consisting only of entries whose values are still strongly reachable. 
*/ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } def get(key: A): Option[B] = internalMap.get(key) - def iterator: Iterator[(A, B)] = internalMap.iterator + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) newMap.internalMap += kv newMap } override def - (key: A): mutable.Map[A, B] = { val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) newMap.internalMap -= key newMap } override def += (kv: (A, B)): this.type = { internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() + } this } @@ -71,31 +85,53 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def apply(key: A): B = internalMap.apply(key) - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p) + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f) + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - def toMap: immutable.Map[A, B] = iterator.toMap + def toMap: Map[A, B] = iterator.toMap - /** - * Remove old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() + } + } + } + + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) + } + + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) + } } /** * Helper methods for converting to and from WeakReferences. */ -private[spark] object TimeStampedWeakValueHashMap { +private object TimeStampedWeakValueHashMap { - /* Implicit conversion methods to WeakReferences */ + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) @@ -107,12 +143,15 @@ private[spark] object TimeStampedWeakValueHashMap { (kv: (K, WeakReference[V])) => p(kv) } - /* Implicit conversion methods from WeakReferences */ + /* Implicit conversion methods from WeakReferences. 
*/ implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v.map(fromWeakReference) + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None + } } implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { @@ -128,5 +167,4 @@ private[spark] object TimeStampedWeakValueHashMap { map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) } - } diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + 
+ test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. 
*/ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. 
*/ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} From c5b1d986e89ce8e52cfc9ac25d44be4b8fc5a259 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 3 Apr 2014 18:45:04 -0700 Subject: [PATCH 27/37] Address Patrick's comments --- .../org/apache/spark/ContextCleaner.scala | 3 -- .../org/apache/spark/MapOutputTracker.scala | 30 +++++++------- .../scala/org/apache/spark/SparkContext.scala | 3 +- .../apache/spark/broadcast/Broadcast.scala | 15 ++++--- .../spark/broadcast/HttpBroadcast.scala | 12 +----- .../spark/broadcast/TorrentBroadcast.scala | 7 +--- .../apache/spark/storage/BlockManager.scala | 4 +- .../spark/storage/BlockManagerMaster.scala | 9 +++-- .../storage/BlockManagerMasterActor.scala | 5 ++- .../storage/BlockManagerSlaveActor.scala | 39 +++++++++++++------ 10 files changed, 70 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index b71b7fa517fd2..7b1e2af1b824f 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -80,7 +80,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Stop the cleaner. */ def stop() { stopped = true - cleaningThread.interrupt() } /** Register a RDD for cleanup when it is garbage collected. */ @@ -119,8 +118,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } } catch { - case ie: InterruptedException => - if (!stopped) logWarning("Cleaning thread interrupted") case t: Throwable => logError("Error in cleaning thread", t) } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c45c5c90048f3..ee82d9fa7874b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -71,13 +71,18 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and worker) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - /** Set to the MapOutputTrackerActor living on the driver */ + /** Set to the MapOutputTrackerActor living on the driver. */ var trackerActor: ActorRef = _ - /** This HashMap needs to have different storage behavior for driver and worker */ + /** + * This HashMap has different behavior for the master and the workers. 
+ * + * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks. + * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the + * master's corresponding HashMap. + */ protected val mapStatuses: Map[Int, Array[MapStatus]] /** @@ -87,7 +92,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on a worker */ + /** Remembers which map output locations are currently being fetched on a worker. */ private val fetching = new HashSet[Int] /** @@ -173,7 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Called to get current epoch number */ + /** Called to get current epoch number. */ def getEpoch: Long = { epochLock.synchronized { return epoch @@ -195,16 +200,13 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Unregister shuffle data */ + /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) } - def stop() { - sendTracker(StopMapOutputTracker) - mapStatuses.clear() - trackerActor = null - } + /** Stop the tracker. */ + def stop() { } } /** @@ -219,7 +221,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) /** * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master, - * so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set). + * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). * Other than these two scenarios, nothing should be dropped from this HashMap. */ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() @@ -314,7 +316,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } override def stop() { - super.stop() + sendTracker(StopMapOutputTracker) + mapStatuses.clear() + trackerActor = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 13fba1e0dfe5d..316b9f0ed8a04 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -643,7 +644,7 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. 
*/ - def broadcast[T](value: T) = { + def broadcast[T](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) cleaner.registerBroadcastForCleanup(bc) bc diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index b28e15a6840d9..e8a97d1754901 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -62,16 +62,21 @@ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T /** - * Remove all persisted state associated with this broadcast on the executors. The next use - * of this broadcast on the executors will trigger a remote fetch. + * Delete cached copies of this broadcast on the executors. If the broadcast is used after + * this is called, it will need to be re-sent to each executor. */ def unpersist() /** - * Remove all persisted state associated with this broadcast on both the executors and the - * driver. Overriding implementations should set isValid to false. + * Remove all persisted state associated with this broadcast on both the executors and + * the driver. */ - private[spark] def destroy() + private[spark] def destroy() { + _isValid = false + onDestroy() + } + + protected def onDestroy() /** * If this broadcast is no longer valid, throw an exception. diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index ec5acf5f23f5f..f4e2e222f4984 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -54,12 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea HttpBroadcast.unpersist(id, removeFromDriver = false) } - /** - * Remove all persisted state associated with this HTTP Broadcast on both the executors - * and the driver. 
- */ - private[spark] def destroy() { - _isValid = false + protected def onDestroy() { HttpBroadcast.unpersist(id, removeFromDriver = true) } @@ -91,7 +86,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea private[spark] object HttpBroadcast extends Logging { private var initialized = false - private var broadcastDir: File = null private var compress: Boolean = false private var bufferSize: Int = 65536 @@ -101,11 +95,9 @@ private[spark] object HttpBroadcast extends Logging { // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] - private var cleaner: MetadataCleaner = null - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null + private var cleaner: MetadataCleaner = null def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 590caa9699dd3..73eeedb8d1f63 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -57,12 +57,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo TorrentBroadcast.unpersist(id, removeFromDriver = false) } - /** - * Remove all persisted state associated with this Torrent broadcast on both the executors - * and the driver. - */ - private[spark] def destroy() { - _isValid = false + protected def onDestroy() { TorrentBroadcast.unpersist(id, removeFromDriver = true) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 925cee1eb6be7..616d24ccd8b6e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -829,12 +829,12 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. */ - def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + def removeBroadcast(broadcastId: Long, tellMaster: Boolean) { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } - blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4e45bb8452fd8..73074e2188e65 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -109,7 +109,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onFailure { + future.onFailure { case e: Throwable => logError("Failed to remove RDD " + rddId, e) } if (blocking) { @@ -117,12 +117,12 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } - /** Remove all blocks belonging to the given shuffle. 
*/ + /** Remove all blocks belonging to the given shuffle asynchronously. */ def removeShuffle(shuffleId: Int) { askDriverWithReply(RemoveShuffle(shuffleId)) } - /** Remove all blocks belonging to the given broadcast. */ + /** Remove all blocks belonging to the given broadcast asynchronously. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) } @@ -142,7 +142,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Return the block's status on all block managers, if any. + * Return the block's status on all block managers, if any. This can potentially be an + * expensive operation and is used mainly for testing. * * If askSlaves is true, this invokes the master to query each block manager for the most * updated block statuses. This is useful when the master is not informed of the given block diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 4159fc733a566..3b63bf3f3774d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -168,7 +168,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { // TODO: Consolidate usages of - val removeMsg = RemoveBroadcast(broadcastId) + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) blockManagerInfo.values .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } .foreach { bm => bm.slaveActor ! removeMsg } @@ -255,7 +255,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Return the block's status for all block managers, if any. + * Return the block's status for all block managers, if any. This can potentially be an + * expensive operation and is used mainly for testing. * * If askSlaves is true, the master queries each block manager for the most updated block * statuses. 
This is useful when the master is not informed of the given block by all block diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 016ade428c68f..2396ca49a7d3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -17,9 +17,11 @@ package org.apache.spark.storage +import scala.concurrent.Future + import akka.actor.Actor -import org.apache.spark.MapOutputTracker +import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ /** @@ -30,25 +32,40 @@ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor { + extends Actor with Logging { - override def receive = { + import context.dispatcher + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receive = { case RemoveBlock(blockId) => - blockManager.removeBlock(blockId) + val removeBlock = Future { blockManager.removeBlock(blockId) } + removeBlock.onFailure { case t: Throwable => + logError("Error in removing block " + blockId, t) + } case RemoveRdd(rddId) => - val numBlocksRemoved = blockManager.removeRdd(rddId) - sender ! numBlocksRemoved + val removeRdd = Future { sender ! blockManager.removeRdd(rddId) } + removeRdd.onFailure { case t: Throwable => + logError("Error in removing RDD " + rddId, t) + } case RemoveShuffle(shuffleId) => - blockManager.shuffleBlockManager.removeShuffle(shuffleId) - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) + val removeShuffle = Future { + blockManager.shuffleBlockManager.removeShuffle(shuffleId) + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + } + removeShuffle.onFailure { case t: Throwable => + logError("Error in removing shuffle " + shuffleId, t) } - case RemoveBroadcast(broadcastId, removeFromDriver) => - blockManager.removeBroadcast(broadcastId, removeFromDriver) + case RemoveBroadcast(broadcastId, tellMaster) => + val removeBroadcast = Future { blockManager.removeBroadcast(broadcastId, tellMaster) } + removeBroadcast.onFailure { case t: Throwable => + logError("Error in removing broadcast " + broadcastId, t) + } case GetBlockStatus(blockId, _) => sender ! 
blockManager.getStatus(blockId) From cd72d192e0371c9ccbcd73a7086bbb6acc234017 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 3 Apr 2014 21:16:43 -0700 Subject: [PATCH 28/37] Make automatic cleanup configurable (not documented) --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++++---- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 316b9f0ed8a04..579c963094a78 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -228,8 +228,12 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(this) dagScheduler.start() - private[spark] val cleaner = new ContextCleaner(this) - cleaner.start() + private[spark] val cleaner: Option[ContextCleaner] = + if (conf.getBoolean("spark.cleaner.automatic", true)) { + Some(new ContextCleaner(this)) + } else None + + cleaner.foreach(_.start()) postEnvironmentUpdate() @@ -646,7 +650,7 @@ class SparkContext( */ def broadcast[T](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) - cleaner.registerBroadcastForCleanup(bc) + cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } @@ -841,7 +845,7 @@ class SparkContext( dagScheduler = null if (dagSchedulerCopy != null) { metadataCleaner.cancel() - cleaner.stop() + cleaner.foreach(_.stop()) dagSchedulerCopy.stop() listenerBus.stop() taskScheduler = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a92922166f595..acd152dda89d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * SchedulerBackends synchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ From a430f068be461f248754260a331e70fceabfb6a7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Apr 2014 04:19:40 -0700 Subject: [PATCH 29/37] Fixed compilation errors. 
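The compilation errors come from SparkContext.cleaner now being an Option[ContextCleaner] (see the previous patch): call sites can no longer invoke the cleaner directly and instead register resources through cleaner.foreach(...), which silently becomes a no-op when automatic cleanup is disabled. A minimal standalone sketch of that call-site pattern, using a stand-in cleaner type (illustrative only, not code from this patch series):

    // Stand-in for ContextCleaner; only the registration call matters here.
    class FakeCleaner {
      def registerRDDForCleanup(rddId: Int): Unit = println("tracking RDD " + rddId)
    }

    class MiniContext(cleanerEnabled: Boolean) {
      // Present only when cleanup is enabled, mirroring SparkContext.cleaner
      private val cleaner: Option[FakeCleaner] =
        if (cleanerEnabled) Some(new FakeCleaner) else None

      def persistRDD(rddId: Int): Unit = {
        // No-op when the cleaner is disabled
        cleaner.foreach(_.registerRDDForCleanup(rddId))
      }
    }

For example, new MiniContext(cleanerEnabled = false).persistRDD(1) simply skips registration instead of throwing.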
--- core/src/main/scala/org/apache/spark/Dependency.scala | 2 +- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../src/test/scala/org/apache/spark/ContextCleanerSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 905328ba6b5a4..1cd629c15bd46 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -56,7 +56,7 @@ class ShuffleDependency[K, V]( val shuffleId: Int = rdd.context.newShuffleId() - rdd.sparkContext.cleaner.registerShuffleForCleanup(this) + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ea22ad29bc885..50dbbe35f3745 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -139,7 +139,7 @@ abstract class RDD[T: ClassTag]( } sc.persistRDD(this) // Register the RDD with the ContextCleaner for automatic GC-based cleanup - sc.cleaner.registerRDDForCleanup(this) + sc.cleaner.foreach(_.registerRDDForCleanup(this)) storageLevel = newLevel this } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 3d95547b20fc1..9eb434ed0ac0e 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -197,7 +197,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } - def cleaner = sc.cleaner + def cleaner = sc.cleaner.get } @@ -235,7 +235,7 @@ class CleanerTester( logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) preCleanupValidate() - sc.cleaner.attachListener(cleanerListener) + sc.cleaner.get.attachListener(cleanerListener) /** Assert that all the stuff has been cleaned up */ def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { From 104a89a7ca744ea9b58095c93b58bd90404e8055 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Apr 2014 09:26:03 -0700 Subject: [PATCH 30/37] Fixed failing BroadcastSuite unit tests by introducing blocking for removeShuffle and removeBroadcast in BlockManager* --- .../org/apache/spark/ContextCleaner.scala | 28 +++--- .../apache/spark/broadcast/Broadcast.scala | 17 +++- .../spark/broadcast/BroadcastFactory.scala | 2 +- .../spark/broadcast/BroadcastManager.scala | 4 +- .../spark/broadcast/HttpBroadcast.scala | 12 +-- .../broadcast/HttpBroadcastFactory.scala | 7 +- .../spark/broadcast/TorrentBroadcast.scala | 12 +-- .../broadcast/TorrentBroadcastFactory.scala | 5 +- .../apache/spark/storage/BlockManager.scala | 3 +- .../spark/storage/BlockManagerMaster.scala | 26 ++++-- .../storage/BlockManagerMasterActor.scala | 37 +++++--- .../storage/BlockManagerSlaveActor.scala | 43 +++++---- .../spark/storage/ShuffleBlockManager.scala | 4 +- .../org/apache/spark/BroadcastSuite.scala | 93 +++++++++++++------ 14 files changed, 190 insertions(+), 103 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 7b1e2af1b824f..75c87a0553a7a 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -112,9 +112,9 @@ private[spark] 
class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Got cleaning task " + task) referenceBuffer -= reference.get task match { - case CleanRDD(rddId) => doCleanupRDD(rddId) - case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId) - case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId) + case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = false) + case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId, blocking = false) + case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = false) } } } catch { @@ -124,10 +124,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform RDD cleanup. */ - private def doCleanupRDD(rddId: Int) { + private def doCleanupRDD(rddId: Int, blocking: Boolean) { try { logDebug("Cleaning RDD " + rddId) - sc.unpersistRDD(rddId, blocking = false) + sc.unpersistRDD(rddId, blocking) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { @@ -135,12 +135,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform shuffle cleanup. */ - private def doCleanupShuffle(shuffleId: Int) { + /** Perform shuffle cleanup, asynchronously. */ + private def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) - blockManagerMaster.removeShuffle(shuffleId) + blockManagerMaster.removeShuffle(shuffleId, blocking) listeners.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { @@ -149,10 +149,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - private def doCleanupBroadcast(broadcastId: Long) { + private def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { try { logDebug("Cleaning broadcast " + broadcastId) - broadcastManager.unbroadcast(broadcastId, removeFromDriver = true) + broadcastManager.unbroadcast(broadcastId, true, blocking) listeners.foreach(_.broadcastCleaned(broadcastId)) logInfo("Cleaned broadcast " + broadcastId) } catch { @@ -164,18 +164,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - // Used for testing + // Used for testing, explicitly blocks until cleanup is completed def cleanupRDD(rdd: RDD[_]) { - doCleanupRDD(rdd.id) + doCleanupRDD(rdd.id, blocking = true) } def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { - doCleanupShuffle(shuffleDependency.shuffleId) + doCleanupShuffle(shuffleDependency.shuffleId, blocking = true) } def cleanupBroadcast[T](broadcast: Broadcast[T]) { - doCleanupBroadcast(broadcast.id) + doCleanupBroadcast(broadcast.id, blocking = true) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e8a97d1754901..f28b6565a830c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -61,22 +61,31 @@ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Asynchronously delete cached copies of this broadcast on the executors. + * If the broadcast is used after this is called, it will need to be re-sent to each executor. 
+ */ + def unpersist() { + unpersist(blocking = false) + } + /** * Delete cached copies of this broadcast on the executors. If the broadcast is used after * this is called, it will need to be re-sent to each executor. + * @param blocking Whether to block until unpersisting has completed */ - def unpersist() + def unpersist(blocking: Boolean) /** * Remove all persisted state associated with this broadcast on both the executors and * the driver. */ - private[spark] def destroy() { + private[spark] def destroy(blocking: Boolean) { _isValid = false - onDestroy() + onDestroy(blocking) } - protected def onDestroy() + protected def onDestroy(blocking: Boolean) /** * If this broadcast is no longer valid, throw an exception. diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 9ff1675e76a5e..a7867bcaabfc2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -29,6 +29,6 @@ import org.apache.spark.SparkConf trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] - def unbroadcast(id: Long, removeFromDriver: Boolean) + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) def stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c3ea16ff9eb5e..cf62aca4d45e8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -60,7 +60,7 @@ private[spark] class BroadcastManager( broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } - def unbroadcast(id: Long, removeFromDriver: Boolean) { - broadcastFactory.unbroadcast(id, removeFromDriver) + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index f4e2e222f4984..2d5e0352f4265 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -50,12 +50,12 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea /** * Remove all persisted state associated with this HTTP broadcast on the executors. */ - def unpersist() { - HttpBroadcast.unpersist(id, removeFromDriver = false) + def unpersist(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) } - protected def onDestroy() { - HttpBroadcast.unpersist(id, removeFromDriver = true) + protected def onDestroy(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) } // Used by the JVM when serializing this object @@ -194,8 +194,8 @@ private[spark] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. 
*/ - def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) files.remove(file.toString) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index 4affa922156c9..2958e4f4c658a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -34,9 +34,10 @@ class HttpBroadcastFactory extends BroadcastFactory { /** * Remove all persisted state associated with the HTTP broadcast with the given ID. - * @param removeFromDriver Whether to remove state from the driver. + * @param removeFromDriver Whether to remove state from the driver + * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver) + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 73eeedb8d1f63..7f37e306f0d07 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -53,12 +53,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo /** * Remove all persisted state associated with this Torrent broadcast on the executors. */ - def unpersist() { - TorrentBroadcast.unpersist(id, removeFromDriver = false) + def unpersist(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) } - protected def onDestroy() { - TorrentBroadcast.unpersist(id, removeFromDriver = true) + protected def onDestroy(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } private def sendBroadcast() { @@ -242,8 +242,8 @@ private[spark] object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index eabe792b550bb..feb0e945fac19 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -36,8 +36,9 @@ class TorrentBroadcastFactory extends BroadcastFactory { /** * Remove all persisted state associated with the torrent broadcast with the given ID. * @param removeFromDriver Whether to remove state from the driver. 
+ * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean) { - TorrentBroadcast.unpersist(id, removeFromDriver) + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 616d24ccd8b6e..4c8e718539ec7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -829,12 +829,13 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. */ - def removeBroadcast(broadcastId: Long, tellMaster: Boolean) { + def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } + blocksToRemove.size } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 73074e2188e65..29300de7d6638 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -117,14 +117,28 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } - /** Remove all blocks belonging to the given shuffle asynchronously. */ - def removeShuffle(shuffleId: Int) { - askDriverWithReply(RemoveShuffle(shuffleId)) + /** Remove all blocks belonging to the given shuffle. */ + def removeShuffle(shuffleId: Int, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + future.onFailure { + case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e) + } + if (blocking) { + Await.result(future, timeout) + } } - /** Remove all blocks belonging to the given broadcast asynchronously. */ - def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { - askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) + /** Remove all blocks belonging to the given broadcast. */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Int]]](RemoveBroadcast(broadcastId, removeFromMaster)) + future.onFailure { + case e: Throwable => + logError("Failed to remove broadcast " + broadcastId + + " with removeFromMaster = " + removeFromMaster, e) + } + if (blocking) { + Await.result(future, timeout) + } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3b63bf3f3774d..f238820942e34 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -100,12 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus sender ! removeRdd(rddId) case RemoveShuffle(shuffleId) => - removeShuffle(shuffleId) - sender ! true + sender ! removeShuffle(shuffleId) case RemoveBroadcast(broadcastId, removeFromDriver) => - removeBroadcast(broadcastId, removeFromDriver) - sender ! true + sender ! 
removeBroadcast(broadcastId, removeFromDriver) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) @@ -150,15 +148,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // The dispatcher is used as an implicit argument into the Future sequence construction. import context.dispatcher val removeMsg = RemoveRdd(rddId) - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) } - private def removeShuffle(shuffleId: Int) { + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { // Nothing to do in the BlockManagerMasterActor data structures + import context.dispatcher val removeMsg = RemoveShuffle(shuffleId) - blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + }.toSeq + ) } /** @@ -166,12 +171,18 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed * from the executors, but not from the driver. */ - private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { // TODO: Consolidate usages of + import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) - blockManagerInfo.values - .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } - .foreach { bm => bm.slaveActor ! removeMsg } + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || info.blockManagerId.executorId != "" + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 2396ca49a7d3f..5c91ad36371bc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import scala.concurrent.Future -import akka.actor.Actor +import akka.actor.{ActorRef, Actor} import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ @@ -39,35 +39,44 @@ class BlockManagerSlaveActor( // Operations that involve removing blocks may be slow and should be done asynchronously override def receive = { case RemoveBlock(blockId) => - val removeBlock = Future { blockManager.removeBlock(blockId) } - removeBlock.onFailure { case t: Throwable => - logError("Error in removing block " + blockId, t) + doAsync("removing block", sender) { + blockManager.removeBlock(blockId) + true } case RemoveRdd(rddId) => - val removeRdd = Future { sender ! 
blockManager.removeRdd(rddId) } - removeRdd.onFailure { case t: Throwable => - logError("Error in removing RDD " + rddId, t) + doAsync("removing RDD", sender) { + blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - val removeShuffle = Future { + doAsync("removing shuffle", sender) { blockManager.shuffleBlockManager.removeShuffle(shuffleId) - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) - } - } - removeShuffle.onFailure { case t: Throwable => - logError("Error in removing shuffle " + shuffleId, t) } case RemoveBroadcast(broadcastId, tellMaster) => - val removeBroadcast = Future { blockManager.removeBroadcast(broadcastId, tellMaster) } - removeBroadcast.onFailure { case t: Throwable => - logError("Error in removing broadcast " + broadcastId, t) + doAsync("removing RDD", sender) { + blockManager.removeBroadcast(broadcastId, tellMaster) } case GetBlockStatus(blockId, _) => sender ! blockManager.getStatus(blockId) } + + private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + val future = Future { + logDebug(actionMessage) + val response = body + response + } + future.onSuccess { case response => + logDebug("Successful in " + actionMessage + ", response is " + response) + responseActor ! response + logDebug("Sent response: " + response + " to " + responseActor) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + responseActor ! null.asInstanceOf[T] + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 06233153c56d4..1f9732565709d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -176,7 +176,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } /** Remove all the blocks / files related to a particular shuffle. 
*/ - private def removeShuffleBlocks(shuffleId: ShuffleId) { + private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => if (consolidateShuffleFiles) { @@ -190,8 +190,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } } logInfo("Deleted all files for shuffle " + shuffleId) + true case None => logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + false } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index f1bfb6666ddda..e2f6ba80e0dbb 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -78,30 +78,48 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } - test("Unpersisting HttpBroadcast on executors only") { - testUnpersistHttpBroadcast(2, removeFromDriver = false) + test("Unpersisting HttpBroadcast on executors only in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) } - test("Unpersisting HttpBroadcast on executors and driver") { - testUnpersistHttpBroadcast(2, removeFromDriver = true) + test("Unpersisting HttpBroadcast on executors and driver in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) } - test("Unpersisting TorrentBroadcast on executors only") { - testUnpersistTorrentBroadcast(2, removeFromDriver = false) + test("Unpersisting HttpBroadcast on executors only in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) } - test("Unpersisting TorrentBroadcast on executors and driver") { - testUnpersistTorrentBroadcast(2, removeFromDriver = true) + test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) } + test("Unpersisting TorrentBroadcast on executors only in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true) + } /** - * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * Verify the persistence of state associated with an HttpBroadcast in either local mode or + * local-cluster mode (when distributed = true). * * This test creates a broadcast variable, uses it on all executors, and then unpersists it. * In between each step, this test verifies that the broadcast blocks and the broadcast file * are present only on the expected nodes. 
*/ - private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) // Verify that the broadcast file is created, and blocks are persisted only on the driver @@ -115,7 +133,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(status.memSize > 0, "Block should be in memory store on the driver") assert(status.diskSize === 0, "Block should not be in disk store on the driver") } - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + if (distributed) { + // this file is only generated in distributed mode + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + } } // Verify that blocks are persisted in both the executors and the driver @@ -138,12 +159,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists, - "Broadcast file should%s be deleted".format(possiblyNot)) + "Block should%s be unpersisted on the driver".format(possiblyNot)) + if (distributed && removeFromDriver) { + // this file is only generated in distributed mode + assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) + } } - testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -154,13 +178,20 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { * In between each step, this test verifies that the broadcast blocks are present only on the * expected nodes. 
*/ - private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + def getBlockIds(id: Long) = { val broadcastBlockId = BroadcastBlockId(id) val metaBlockId = BroadcastBlockId(id, "meta") // Assume broadcast value is small enough to fit into 1 piece val pieceBlockId = BroadcastBlockId(id, "piece0") - Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) + if (distributed) { + // the metadata and piece blocks are generated only in distributed mode + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } else { + Seq[BroadcastBlockId](broadcastBlockId) + } } // Verify that blocks are persisted only on the driver @@ -187,7 +218,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { statuses.head match { case (bm, _) => assert(bm.executorId === "") } } else { // Other blocks are on both the executors and the driver - assert(statuses.size === numSlaves + 1) + assert(statuses.size === numSlaves + 1, + blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) statuses.foreach { case (_, status) => assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) assert(status.memSize > 0, "Block should be in memory store") @@ -209,7 +241,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } } - testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -223,7 +255,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. 
*/ private def testUnpersistBroadcast( - numSlaves: Int, + distributed: Boolean, + numSlaves: Int, // used only when distributed = true broadcastConf: SparkConf, getBlockIds: Long => Seq[BroadcastBlockId], afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, @@ -231,7 +264,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, removeFromDriver: Boolean) { - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + sc = if (distributed) { + new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + } else { + new SparkContext("local", "test", broadcastConf) + } val blockManagerMaster = sc.env.blockManager.master val list = List[Int](1, 2, 3, 4) @@ -241,15 +278,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { afterCreation(blocks, blockManagerMaster) // Use broadcast variable on all executors - val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + val partitions = 10 + assert(partitions > numSlaves) + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) afterUsingBroadcast(blocks, blockManagerMaster) // Unpersist broadcast if (removeFromDriver) { - broadcast.destroy() + broadcast.destroy(blocking = true) } else { - broadcast.unpersist() + broadcast.unpersist(blocking = true) } afterUnpersist(blocks, blockManagerMaster) @@ -260,8 +299,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Instead, crash the driver by directly accessing the broadcast value. intercept[SparkException] { broadcast.value } } else { - val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) } } From 6222697f41a69ac12717f3b421aa303b26c4642c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Apr 2014 10:10:59 -0700 Subject: [PATCH 31/37] Fixed bug and added unit test for removeBroadcast in BlockManagerSuite. --- .../scala/org/apache/spark/storage/ShuffleBlockManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 1f9732565709d..4eeeb9aa9c7ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -170,9 +170,9 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } /** Remove all the blocks / files and metadata related to a particular shuffle. */ - def removeShuffle(shuffleId: ShuffleId) { - removeShuffleBlocks(shuffleId) + def removeShuffle(shuffleId: ShuffleId): Boolean = { shuffleStates.remove(shuffleId) + removeShuffleBlocks(shuffleId) } /** Remove all the blocks / files related to a particular shuffle. 
*/ From 41c9ecec21e61bcc077d6e3ea052a3b7a2d4b01a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 6 Apr 2014 17:49:48 -0700 Subject: [PATCH 32/37] Added more unit tests for BlockManager, DiskBlockManager, and ContextCleaner. --- .../org/apache/spark/ContextCleaner.scala | 12 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 9 + .../spark/storage/BlockManagerMaster.scala | 23 ++- .../storage/BlockManagerMasterActor.scala | 33 ++- .../spark/storage/BlockManagerMessages.scala | 3 + .../storage/BlockManagerSlaveActor.scala | 16 +- .../spark/storage/DiskBlockManager.scala | 10 + .../spark/storage/ShuffleBlockManager.scala | 5 +- .../org/apache/spark/BroadcastSuite.scala | 12 +- .../apache/spark/ContextCleanerSuite.scala | 189 +++++++++++++----- .../spark/storage/BlockManagerSuite.scala | 106 ++++++++++ .../spark/storage/DiskBlockManagerSuite.scala | 7 + 13 files changed, 355 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 75c87a0553a7a..cd29ac05bdfb2 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -63,6 +63,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + /** Whether the cleaning thread will block on cleanup tasks */ + private val blockOnCleanupTasks = sc.conf.getBoolean("spark.cleaner.referenceTracking.blocking", false) + @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ @@ -112,9 +115,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Got cleaning task " + task) referenceBuffer -= reference.get task match { - case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = false) - case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId, blocking = false) - case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = false) + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 579c963094a78..c8d659d656ef4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -229,7 +229,7 @@ class SparkContext( dagScheduler.start() private[spark] val cleaner: Option[ContextCleaner] = - if (conf.getBoolean("spark.cleaner.automatic", true)) { + if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) } else None diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4c8e718539ec7..e684831c00abc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -218,6 +218,15 @@ private[spark] class BlockManager( } } + /** + * Get the ids of existing blocks that match the given filter. 
Note that this will + * query the blocks stored in the disk block manager (that the block manager + * may not know of). + */ + def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { + (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + } + /** * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 29300de7d6638..d939c5da96967 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -130,7 +130,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveBroadcast(broadcastId, removeFromMaster)) + val future = askDriverWithReply[Future[Seq[Int]]]( + RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Throwable => logError("Failed to remove broadcast " + broadcastId + @@ -156,8 +157,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Return the block's status on all block managers, if any. This can potentially be an - * expensive operation and is used mainly for testing. + * Return the block's status on all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. * * If askSlaves is true, this invokes the master to query each block manager for the most * updated block statuses. This is useful when the master is not informed of the given block @@ -184,6 +185,22 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log }.toMap } + /** + * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This + * is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getMatcinghBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Seq[BlockId] = { + val msg = GetMatchingBlockIds(filter, askSlaves) + val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + Await.result(future, timeout) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index f238820942e34..69f261b2002a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -96,6 +96,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetBlockStatus(blockId, askSlaves) => sender ! blockStatus(blockId, askSlaves) + case GetMatchingBlockIds(filter, askSlaves) => + sender ! getMatchingBlockIds(filter, askSlaves) + case RemoveRdd(rddId) => sender ! 
removeRdd(rddId) @@ -266,8 +269,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Return the block's status for all block managers, if any. This can potentially be an - * expensive operation and is used mainly for testing. + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. * * If askSlaves is true, the master queries each block manager for the most updated block * statuses. This is useful when the master is not informed of the given block by all block @@ -294,6 +297,32 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toMap } + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + import context.dispatcher + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + } else { + Future { info.blocks.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index afb2c6a12ce67..365e3900731dc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -115,5 +115,8 @@ private[storage] object BlockManagerMessages { case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) extends ToBlockManagerMaster + case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) + extends ToBlockManagerMaster + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 5c91ad36371bc..fc22f54ceb9d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -39,28 +39,34 @@ class BlockManagerSlaveActor( // Operations that involve removing blocks may be slow and should be done asynchronously override def receive = { case RemoveBlock(blockId) => - doAsync("removing block", sender) { + doAsync[Boolean]("removing block", sender) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync("removing RDD", sender) { + doAsync[Int]("removing RDD", sender) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync("removing shuffle", sender) { + doAsync[Boolean]("removing shuffle", sender) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } blockManager.shuffleBlockManager.removeShuffle(shuffleId) } case 
RemoveBroadcast(broadcastId, tellMaster) => - doAsync("removing RDD", sender) { + doAsync[Int]("removing broadcast", sender) { blockManager.removeBroadcast(broadcastId, tellMaster) } case GetBlockStatus(blockId, _) => sender ! blockManager.getStatus(blockId) + + case GetMatchingBlockIds(filter, _) => + sender ! blockManager.getMatchingBlockIds(filter) } private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { @@ -70,7 +76,7 @@ class BlockManagerSlaveActor( response } future.onSuccess { case response => - logDebug("Successful in " + actionMessage + ", response is " + response) + logDebug("Done " + actionMessage + ", response is " + response) responseActor ! response logDebug("Sent response: " + response + " to " + responseActor) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index fcad84669c79a..47a1a6d4a5869 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,6 +47,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private var shuffleSender : ShuffleSender = null + addShutdownHook() /** @@ -95,6 +96,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD getBlockLocation(blockId).file.exists() } + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + // Get all the files inside the array of arrays of directories + subDirs.flatten.filter(_ != null).flatMap { dir => + val files = dir.list() + if (files != null) files else Seq.empty + }.map(BlockId.apply) + } + /** Produces a unique block id and File suitable for intermediate results. */ def createTempBlock(): (TempBlockId, File) = { var blockId = new TempBlockId(UUID.randomUUID()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 4eeeb9aa9c7ab..4cd4cdbd9909d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -171,8 +171,11 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { /** Remove all the blocks / files and metadata related to a particular shuffle. */ def removeShuffle(shuffleId: ShuffleId): Boolean = { + // Do not change the ordering of this; shuffleStates should be removed only + // after the corresponding shuffle blocks have been removed + val cleaned = removeShuffleBlocks(shuffleId) shuffleStates.remove(shuffleId) - removeShuffleBlocks(shuffleId) + cleaned } /** Remove all the blocks / files related to a particular shuffle. 
*/ diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e2f6ba80e0dbb..79dcc4a159235 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -125,7 +125,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that the broadcast file is created, and blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => assert(bm.executorId === "", "Block should only be on the driver") @@ -142,7 +142,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) assert(statuses.size === numSlaves + 1) statuses.foreach { case (_, status) => assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) @@ -155,7 +155,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. In the latter case, also verify that the broadcast file is deleted on the driver. def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" assert(statuses.size === expectedNumBlocks, @@ -197,7 +197,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockIds.head) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => assert(bm.executorId === "", "Block should only be on the driver") @@ -211,7 +211,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) if (blockId.field == "meta") { // Meta data is only on the driver assert(statuses.size === 1) @@ -235,7 +235,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === expectedNumBlocks, "Block should%s be unpersisted on the driver".format(possiblyNot)) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 9eb434ed0ac0e..345bee6930c49 100644 --- 
a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -29,16 +29,28 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BroadcastBlockId, RDDBlockId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { implicit val defaultTimeout = timeout(10000 millis) + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") before { - sc = new SparkContext("local[2]", "CleanerSuite") + sc = new SparkContext(conf) } + after { + if (sc != null) { + sc.stop() + sc = null + } + } + + test("cleanup RDD") { val rdd = newRDD.persist() val collected = rdd.collect().toList @@ -150,6 +162,40 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo postGCTester.assertCleanup() } + test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { + sc.stop() + + val conf2 = new SparkConf() + .setMaster("local[4]") + //.setMaster("local-cluster[2, 1, 512]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + sc = new SparkContext(conf2) + + val numRdds = 10 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + } + + //------ Helper functions ------ + def newRDD = sc.makeRDD(1 to 10) def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) @@ -192,7 +238,6 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Wait until a weak reference object has been GCed while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { System.gc() - System.runFinalization() Thread.sleep(200) } } @@ -212,6 +257,7 @@ class CleanerTester( val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { @@ -240,10 +286,9 @@ class CleanerTester( /** Assert that all the stuff has been cleaned up */ def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { try { - eventually(waitTimeout, interval(10 millis)) { + eventually(waitTimeout, interval(100 millis)) { assert(isAllCleanedUp) } - Thread.sleep(100) // to allow async cleanup actions to be completed postCleanupValidate() } finally { logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString) @@ -255,20 +300,41 @@ class 
CleanerTester( assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present - assert(rddIds.forall(sc.persistentRdds.contains), - "One or more RDDs have not been persisted, cannot start cleaner test") - assert(rddIds.forall(rddId => blockManager.master.contains(rddBlockId(rddId))), - "One or more RDDs' blocks cannot be found in block manager, cannot start cleaner test") + rddIds.foreach { rddId => + assert( + sc.persistentRdds.contains(rddId), + "RDD " + rddId + " has not been persisted, cannot start cleaner test" + ) + + assert( + !getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } // Verify the shuffle ids are registered and blocks are present - assert(shuffleIds.forall(mapOutputTrackerMaster.containsShuffle), - "One or more shuffles have not been registered cannot start cleaner test") - assert(shuffleIds.forall(sid => diskBlockManager.containsBlock(shuffleBlockId(sid))), - "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") - - // Verify that the broadcast is in the driver's block manager - assert(broadcastIds.forall(bid => blockManager.getStatus(broadcastBlockId(bid)).isDefined), - "One ore more broadcasts have not been persisted in the driver's block manager") + shuffleIds.foreach { shuffleId => + assert( + mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " has not been registered, cannot start cleaner test" + ) + + assert( + !getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + + // Verify that the broadcast blocks are present + broadcastIds.foreach { broadcastId => + assert( + !getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } } /** @@ -276,41 +342,46 @@ class CleanerTester( * as there is not guarantee on how long it will take clean up the resources. 
/** @@ -276,41 +342,46 @@ class CleanerTester( * as there is not guarantee on how long it will take clean up the resources. */ private def postCleanupValidate() { - var attempts = 0 - while (attempts < MAX_VALIDATION_ATTEMPTS) { - attempts += 1 - logInfo("Attempt: " + attempts) - try { - // Verify all RDDs have been unpersisted - assert(rddIds.forall(!sc.persistentRdds.contains(_))) - assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) - - // Verify all shuffles have been deregistered and cleaned up - assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) - assert(shuffleIds.forall(sid => !diskBlockManager.containsBlock(shuffleBlockId(sid)))) - - // Verify all broadcasts have been unpersisted - assert(broadcastIds.forall { bid => - blockManager.master.getBlockStatus(broadcastBlockId(bid)).isEmpty - }) - - return - } catch { - case t: Throwable => - if (attempts >= MAX_VALIDATION_ATTEMPTS) { - throw t - } else { - Thread.sleep(VALIDATION_ATTEMPT_INTERVAL) - } - } + // Verify the RDDs have been unpersisted and their blocks removed + rddIds.foreach { rddId => + assert( + !sc.persistentRdds.contains(rddId), + "RDD " + rddId + " was not cleared from sc.persistentRdds" + ) + + assert( + getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " were not cleared from block manager" + ) + } + + // Verify the shuffles have been deregistered and their blocks removed + shuffleIds.foreach { shuffleId => + assert( + !mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " was not deregistered from map output tracker" + ) + + assert( + getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " were not cleared from block manager" + ) + } + + // Verify that the broadcast blocks have been removed + broadcastIds.foreach { broadcastId => + assert( + getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + " were not cleared from block manager" + ) } } private def uncleanedResourcesToString = { s""" - |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} - |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} - |\tBroadcasts = ${toBeCleanedBroadcstIds.mkString("[", ", ", "]")} + |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")} """.stripMargin } @@ -319,11 +390,27 @@ class CleanerTester( toBeCleanedShuffleIds.isEmpty && toBeCleanedBroadcstIds.isEmpty - private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) - private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) - private def broadcastBlockId(broadcastId: Long) = BroadcastBlockId(broadcastId) + private def getRDDBlocks(rddId: Int): Seq[BlockId] = { + blockManager.master.getMatcinghBlockIds( _ match { + case RDDBlockId(rddId, _) => true + case _ => false + }, askSlaves = true) + } + + private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { + blockManager.master.getMatcinghBlockIds( _ match { + case ShuffleBlockId(shuffleId, _, _) => true + case _ => false + }, askSlaves = true) + } + + private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { + blockManager.master.getMatcinghBlockIds( _ match { + case BroadcastBlockId(broadcastId, _) => true + case _ => false + }, askSlaves = true) + } private def blockManager = sc.env.blockManager - private def diskBlockManager = blockManager.diskBlockManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } diff --git
a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b47de5eab95a4..970b4f70ee6d7 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -262,6 +262,78 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.getLocations(rdd(0, 1)) should have size 0 } + test("removing broadcast") { + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val driverStore = store + val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + val a4 = new Array[Byte](400) + + val broadcast0BlockId = BroadcastBlockId(0) + val broadcast1BlockId = BroadcastBlockId(1) + val broadcast2BlockId = BroadcastBlockId(2) + val broadcast2BlockId2 = BroadcastBlockId(2, "_") + + // insert broadcast blocks in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) + s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) + } + + // verify whether the blocks exist in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.getLocal(broadcast0BlockId) should not be (None) + s.getLocal(broadcast1BlockId) should not be (None) + s.getLocal(broadcast2BlockId) should not be (None) + s.getLocal(broadcast2BlockId2) should not be (None) + } + + // remove broadcast 0 block only from executors + master.removeBroadcast(0, removeFromMaster = false, blocking = true) + + // only broadcast 0 block should be removed from the executor store + executorStore.getLocal(broadcast0BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should not be (None) + executorStore.getLocal(broadcast2BlockId) should not be (None) + + // nothing should be removed from the driver store + driverStore.getLocal(broadcast0BlockId) should not be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + driverStore.getLocal(broadcast2BlockId) should not be (None) + + // remove broadcast 0 block from the driver as well + master.removeBroadcast(0, removeFromMaster = true, blocking = true) + driverStore.getLocal(broadcast0BlockId) should be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + + // remove broadcast 1 block from both the stores asynchronously + // and verify all broadcast 1 blocks have been removed + master.removeBroadcast(1, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast1BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should be (None) + } + + // remove broadcast 2 from both the stores asynchronously + // and verify all broadcast 2 blocks have been removed + master.removeBroadcast(2, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast2BlockId) should be (None) + driverStore.getLocal(broadcast2BlockId2) should be (None) + executorStore.getLocal(broadcast2BlockId) should be (None) + executorStore.getLocal(broadcast2BlockId2) should be (None) 
+ } + executorStore.stop() + driverStore.stop() + store = null + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager("", actorSystem, master, serializer, 2000, conf, @@ -785,6 +857,40 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) } + test("get matching blocks") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](10)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatcinghBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) + assert(store.master.getMatcinghBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatcinghBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) + assert(store.master.getMatcinghBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + + val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) + blockIds.foreach { blockId => + store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + val matchedBlockIds = store.master.getMatcinghBlockIds(_ match { + case RDDBlockId(1, _) => true + case _ => false + }, askSlaves = true) + assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), RDDBlockId(1, 1))) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 0dd34223787cd..808ddfdcf45d8 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -64,6 +64,13 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { assert(!diskBlockManager.containsBlock(blockId)) } + test("enumerating blocks") { + val ids = (1 to 100).map(i => TestBlockId("test_" + i)) + val files = ids.map(id => diskBlockManager.getFile(id)) + files.foreach(file => writeToFile(file, 10)) + assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) + } + test("block appending") { val blockId = new TestBlockId("test") val newFile = diskBlockManager.getFile(blockId) From 2b95b5eefcde33a401d4ae2f4c568005c5a77650 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 6 Apr 2014 20:04:51 -0700 Subject: [PATCH 33/37] Added more documentation on Broadcast implementations, specially which blocks are told about to the driver. 
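The phrase "which blocks are told about to the driver" is the key idea behind the tellMaster / askSlaves flags exercised in the tests above. A small self-contained model of those semantics, using made-up stand-in classes rather than Spark's BlockManager API:

// A block put with tellMaster = false is only known to the executor that holds it,
// so a master-only query misses it, while a query with askSlaves = true finds it.
class ToyExecutor { val local = scala.collection.mutable.Set.empty[String] }

class ToyMaster(executors: Seq[ToyExecutor]) {
  private val reported = scala.collection.mutable.Set.empty[String]

  def put(exec: ToyExecutor, blockId: String, tellMaster: Boolean) {
    exec.local += blockId
    if (tellMaster) reported += blockId
  }

  def getMatchingBlockIds(filter: String => Boolean, askSlaves: Boolean): Seq[String] = {
    val known = if (askSlaves) reported ++ executors.flatMap(_.local) else reported
    known.toSeq.filter(filter)
  }
}

val exec = new ToyExecutor
val master = new ToyMaster(Seq(exec))
master.put(exec, "newlist1", tellMaster = true)
master.put(exec, "newlist2", tellMaster = false)
master.getMatchingBlockIds(_.startsWith("newlist"), askSlaves = false)  // only "newlist1"
master.getMatchingBlockIds(_.startsWith("newlist"), askSlaves = true)   // both blocks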
Also, fixed Broadcast API to hide destroy functionality. --- .../org/apache/spark/ContextCleaner.scala | 2 +- .../apache/spark/broadcast/Broadcast.scala | 58 ++++++++++++++----- .../spark/broadcast/HttpBroadcast.scala | 36 +++++++++--- .../broadcast/HttpBroadcastFactory.scala | 4 +- .../spark/broadcast/TorrentBroadcast.scala | 52 +++++++++++++---- .../broadcast/TorrentBroadcastFactory.scala | 4 +- .../org/apache/spark/BroadcastSuite.scala | 4 +- 7 files changed, 122 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index cd29ac05bdfb2..7cfd7cf06d33d 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -76,7 +76,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Start the cleaner. */ def start() { cleaningThread.setDaemon(true) - cleaningThread.setName("ContextCleaner") + cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index f28b6565a830c..738a3b1bed7f3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -28,7 +28,8 @@ import org.apache.spark.SparkException * attempts to distribute broadcast variables using efficient broadcast algorithms to reduce * communication cost. * - * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]]. + * Broadcast variables are created from a variable `v` by calling + * [[org.apache.spark.SparkContext#broadcast]]. * The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the * `value` method. The interpreter session below shows this: * @@ -51,15 +52,17 @@ import org.apache.spark.SparkException */ abstract class Broadcast[T](val id: Long) extends Serializable { - protected var _isValid: Boolean = true - /** - * Whether this Broadcast is actually usable. This should be false once persisted state is - * removed from the driver. + * Flag signifying whether the broadcast variable is valid + * (that is, not already destroyed) or not. */ - def isValid: Boolean = _isValid + @volatile private var _isValid = true - def value: T + /** Get the broadcasted value. */ + def value: T = { + assertValid() + getValue() + } /** * Asynchronously delete cached copies of this broadcast on the executors. @@ -74,23 +77,50 @@ abstract class Broadcast[T](val id: Long) extends Serializable { * this is called, it will need to be re-sent to each executor. * @param blocking Whether to block until unpersisting has completed */ - def unpersist(blocking: Boolean) + def unpersist(blocking: Boolean) { + assertValid() + doUnpersist(blocking) + } /** - * Remove all persisted state associated with this broadcast on both the executors and - * the driver. + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. */ private[spark] def destroy(blocking: Boolean) { + assertValid() _isValid = false - onDestroy(blocking) + doDestroy(blocking) } - protected def onDestroy(blocking: Boolean) + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. 
+ */ + private[spark] def isValid: Boolean = { + _isValid + } + + /** + * Actually get the broadcasted value. Concrete implementations of Broadcast class must + * define their own way to get the value. + */ + private[spark] def getValue(): T /** - * If this broadcast is no longer valid, throw an exception. + * Actually unpersist the broadcasted value on the executors. Concrete implementations of + * Broadcast class must define their own logic to unpersist their own data. */ - protected def assertValid() { + private[spark] def doUnpersist(blocking: Boolean) + + /** + * Actually destroy all data and metadata related to this broadcast variable. + * Implementation of Broadcast class must define their own logic to destroy their own + * state. + */ + private[spark] def doDestroy(blocking: Boolean) + + /** Check if this broadcast is valid. If not valid, exception is thrown. */ + private[spark] def assertValid() { if (!_isValid) { throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 2d5e0352f4265..02158afa972a2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -28,16 +28,26 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server + * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a + * task) is deserialized in the executor, the broadcasted data is fetched from the driver + * (through a HTTP server running at the driver) and stored in the BlockManager of the + * executor to speed up future accesses. + */ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value: T = { - assertValid() - value_ - } + def getValue = value_ val blockId = BroadcastBlockId(id) + /* + * Broadcasted data is also stored in the BlockManager of the driver. + * The BlockManagerMaster + * does not need to be told about this block as not only + * need to know about this data block. + */ HttpBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) @@ -50,21 +60,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea /** * Remove all persisted state associated with this HTTP broadcast on the executors. */ - def unpersist(blocking: Boolean) { + def doUnpersist(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) } - protected def onDestroy(blocking: Boolean) { + /** + * Remove all persisted state associated with this HTTP broadcast on the executors and driver. + */ + def doDestroy(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) } - // Used by the JVM when serializing this object + /** Used by the JVM when serializing this object. */ private def writeObject(out: ObjectOutputStream) { assertValid() out.defaultWriteObject() } - // Used by the JVM when deserializing this object + /** Used by the JVM when deserializing this object. 
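A condensed sketch of the pattern the reworked Broadcast base class follows, shown here as a standalone class for illustration (the real class also restricts the hooks to private[spark] and uses Spark's own exception type): each public method first checks the validity flag, then delegates to an implementation-specific hook.

abstract class SimpleBroadcast[T](val id: Long) extends Serializable {
  // Cleared exactly once, by destroy(); all public entry points check it first.
  @volatile private var _isValid = true

  def value: T = { assertValid(); getValue() }
  def unpersist(blocking: Boolean = false) { assertValid(); doUnpersist(blocking) }
  def destroy(blocking: Boolean) { assertValid(); _isValid = false; doDestroy(blocking) }

  // Hooks that concrete implementations (HTTP-like, torrent-like, ...) fill in.
  protected def getValue(): T
  protected def doUnpersist(blocking: Boolean)
  protected def doDestroy(blocking: Boolean)

  protected def assertValid() {
    if (!_isValid) {
      throw new IllegalStateException("Attempted to use broadcast " + id + " after destroy()")
    }
  }
}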
*/ private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { @@ -74,6 +87,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) + /* + * Storing the broadcast data in BlockManager so that all + * so that all subsequent tasks using the broadcast variable + * does not need to fetch it again. The BlockManagerMaster + * does not need to be told about this block as no one + * needs to know about this data block. + */ SparkEnv.get.blockManager.putSingle( blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index 2958e4f4c658a..e3f6cdc6154dd 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -20,7 +20,9 @@ package org.apache.spark.broadcast import org.apache.spark.{SecurityManager, SparkConf} /** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. + * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses an + * HTTP server as the broadcast mechanism. Refer to + * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. */ class HttpBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7f37e306f0d07..2b32546c6854d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,13 +26,28 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. + * The mechanism is as follows. The driver serializes the broadcasted data, + * divides it into smaller chunks, and stores them in the BlockManager of the driver. + * These chunks are reported to the BlockManagerMaster so that all the executors can + * learn the location of those chunks. The first time the broadcast variable (sent as + * part of a task) is deserialized at an executor, all the chunks are fetched using + * the BlockManager. When all the chunks are fetched (initially from the driver's + * BlockManager), they are combined and deserialized to recreate the broadcasted data. + * However, the chunks are also stored in the BlockManager and reported to the + * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns + * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be + * made to other executors who already have those chunks, resulting in distributed + * fetching. This prevents the driver from being the bottleneck in sending out multiple + * copies of the broadcast data (one per executor) as done by the + * [[org.apache.spark.broadcast.HttpBroadcast]].
+ */ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = { - assertValid() - value_ - } + def getValue = value_ val broadcastId = BroadcastBlockId(id) @@ -53,15 +68,19 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo /** * Remove all persisted state associated with this Torrent broadcast on the executors. */ - def unpersist(blocking: Boolean) { + def doUnpersist(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) } - protected def onDestroy(blocking: Boolean) { + /** + * Remove all persisted state associated with this Torrent broadcast on the executors + * and driver. + */ + def doDestroy(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - private def sendBroadcast() { + def sendBroadcast() { val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes @@ -85,13 +104,13 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - // Used by the JVM when serializing this object + /** Used by the JVM when serializing this object. */ private def writeObject(out: ObjectOutputStream) { assertValid() out.defaultWriteObject() } - // Used by the JVM when deserializing this object + /** Used by the JVM when deserializing this object. */ private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { @@ -111,7 +130,11 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. * This creates a trade-off between memory usage and latency. Storing copy doubles - * the memory footprint; not storing doubles deserialization cost. */ + * the memory footprint; not storing doubles deserialization cost. Also, + * this does not need to be reported to BlockManagerMaster since other executors + * does not need to access this block (they only need to fetch the chunks, + * which are reported). + */ SparkEnv.get.blockManager.putSingle( broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) @@ -135,7 +158,8 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } def receiveBroadcast(): Boolean = { - // Receive meta-info + // Receive meta-info about the size of broadcast data, + // the number of chunks it is divided into, etc. val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { @@ -158,7 +182,11 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo return false } - // Receive actual blocks + /* + * Fetch actual chunks of data. Note that all these chunks are stored in + * the BlockManager and reported to the master, so that other executors + * can find out and pull the chunks from this executor. 
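The "blockify" step mentioned above amounts to "serialize the value, then cut the bytes into fixed-size pieces". A rough standalone sketch of that idea, using plain Java serialization and an assumed chunk size rather than Spark's configured serializer and block size:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Serialize a value and split the resulting bytes into chunks that could each be
// stored and fetched as an individual block.
def blockify(value: AnyRef, chunkSize: Int = 4 * 1024 * 1024): Array[Array[Byte]] = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(value)
  out.close()
  bytes.toByteArray.grouped(chunkSize).toArray
}

// Reassembly on the fetching side is the reverse: concatenate the chunks in order
// and deserialize.
def unBlockify[T](chunks: Array[Array[Byte]]): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(chunks.flatten))
  try in.readObject().asInstanceOf[T] finally in.close()
}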
+ */ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { val pieceId = BroadcastBlockId(id, "piece" + pid) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index feb0e945fac19..d216b58718148 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -20,7 +20,9 @@ package org.apache.spark.broadcast import org.apache.spark.{SecurityManager, SparkConf} /** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to + * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ class TorrentBroadcastFactory extends BroadcastFactory { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 79dcc4a159235..c9936256a5b95 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import org.scalatest.FunSuite import org.apache.spark.storage._ -import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.broadcast.{Broadcast, HttpBroadcast} import org.apache.spark.storage.BroadcastBlockId class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -298,6 +298,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Using this variable on the executors crashes them, which hangs the test. // Instead, crash the driver by directly accessing the broadcast value. intercept[SparkException] { broadcast.value } + intercept[SparkException] { broadcast.unpersist() } + intercept[SparkException] { broadcast.destroy(blocking = true) } } else { val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) From 4d05314ca204d3792de0411169ee4f4a5169594f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 6 Apr 2014 21:09:55 -0700 Subject: [PATCH 34/37] Scala style fix. --- .../main/scala/org/apache/spark/ContextCleaner.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 7cfd7cf06d33d..250d9d55c6211 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -63,8 +63,11 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - /** Whether the cleaning thread will block on cleanup tasks */ - private val blockOnCleanupTasks = sc.conf.getBoolean("spark.cleaner.referenceTracking.blocking", false) + /** + * Whether the cleaning thread will block on cleanup tasks. + * This is set to true only for tests. 
*/ + private val blockOnCleanupTasks = sc.conf.getBoolean( + "spark.cleaner.referenceTracking.blocking", false) @volatile private var stopped = false @@ -170,7 +173,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - // Used for testing, explicitly blocks until cleanup is completed + // Used for testing. These methods explicitly blocks until cleanup is completed + // to ensure that more reliable testing. def cleanupRDD(rdd: RDD[_]) { doCleanupRDD(rdd.id, blocking = true) From cff023c4cde102834c8a0fb12d7d8500f33675e8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Apr 2014 13:32:17 -0700 Subject: [PATCH 35/37] Fixed issues based on Andrew's comments. --- .../apache/spark/broadcast/HttpBroadcast.scala | 6 ++---- .../spark/storage/BlockManagerMaster.scala | 2 +- .../spark/storage/BlockManagerSlaveActor.scala | 11 +++++------ .../apache/spark/storage/DiskBlockManager.scala | 1 - .../org/apache/spark/ContextCleanerSuite.scala | 17 ++++++++--------- .../spark/storage/BlockManagerSuite.scala | 14 +++++++------- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- 7 files changed, 24 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 02158afa972a2..51399bb980fcd 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -43,10 +43,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea val blockId = BroadcastBlockId(id) /* - * Broadcasted data is also stored in the BlockManager of the driver. - * The BlockManagerMaster - * does not need to be told about this block as not only - * need to know about this data block. + * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster + * does not need to be told about this block as not only need to know about this data block. */ HttpBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d939c5da96967..4191f4e4c71e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -193,7 +193,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * updated block statuses. This is useful when the master is not informed of the given block * by all block managers. 
*/ - def getMatcinghBlockIds( + def getMatchinghBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index fc22f54ceb9d8..6d4db064dff58 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -39,18 +39,18 @@ class BlockManagerSlaveActor( // Operations that involve removing blocks may be slow and should be done asynchronously override def receive = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block", sender) { + doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD", sender) { + doAsync[Int]("removing RDD " + rddId, sender) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle", sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, sender) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -58,7 +58,7 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, tellMaster) => - doAsync[Int]("removing RDD", sender) { + doAsync[Int]("removing broadcast " + broadcastId, sender) { blockManager.removeBroadcast(broadcastId, tellMaster) } @@ -72,8 +72,7 @@ class BlockManagerSlaveActor( private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { val future = Future { logDebug(actionMessage) - val response = body - response + body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 47a1a6d4a5869..866c52150a48e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,7 +47,6 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private var shuffleSender : ShuffleSender = null - addShutdownHook() /** diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 345bee6930c49..c828e4ebcf924 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -166,8 +166,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo sc.stop() val conf2 = new SparkConf() - .setMaster("local[4]") - //.setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 512]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") sc = new SparkContext(conf2) @@ -180,7 +179,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val shuffleIds = 0 until sc.newShuffleId val broadcastIds = 0L until numBroadcasts - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) @@ -391,22 +390,22 
@@ class CleanerTester( toBeCleanedBroadcstIds.isEmpty private def getRDDBlocks(rddId: Int): Seq[BlockId] = { - blockManager.master.getMatcinghBlockIds( _ match { - case RDDBlockId(rddId, _) => true + blockManager.master.getMatchinghBlockIds( _ match { + case RDDBlockId(`rddId`, _) => true case _ => false }, askSlaves = true) } private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { - blockManager.master.getMatcinghBlockIds( _ match { - case ShuffleBlockId(shuffleId, _, _) => true + blockManager.master.getMatchinghBlockIds( _ match { + case ShuffleBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { - blockManager.master.getMatcinghBlockIds( _ match { - case BroadcastBlockId(broadcastId, _) => true + blockManager.master.getMatchinghBlockIds( _ match { + case BroadcastBlockId(`broadcastId`, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 970b4f70ee6d7..fb1920bd47fb1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -862,29 +862,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT securityMgr, mapOutputTracker) val list = List.fill(2)(new Array[Byte](10)) - // Tell master. By LRU, only list2 and list3 remains. + // insert some blocks store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatcinghBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatcinghBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchinghBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) + assert(store.master.getMatchinghBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) - // Tell master. By LRU, only list2 and list3 remains. 
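The backticks added above are easy to overlook but are what make these filters correct: in a Scala pattern, a lowercase name binds a new variable that shadows the outer value, while a backticked name is matched by equality against it. A standalone illustration with local case classes that merely mirror the shape of Spark's block ids:

sealed trait Id
case class RddId(rddId: Int, split: Int) extends Id
case class ShuffleId(shuffleId: Int, mapId: Int, reduceId: Int) extends Id

val rddId = 1
val ids: Seq[Id] = Seq(RddId(1, 0), RddId(2, 0), ShuffleId(0, 0, 0))

// `rddId` in the pattern below is a fresh variable, so every RddId matches:
val tooMany = ids.filter {
  case RddId(rddId, _) => true
  case _ => false
}
// tooMany == Seq(RddId(1, 0), RddId(2, 0)): the outer rddId was ignored

// Backticks compare against the outer rddId instead of binding a new one:
val justMine = ids.filter {
  case RddId(`rddId`, _) => true
  case _ => false
}
// justMine == Seq(RddId(1, 0)): only ids whose first field equals the outer rddId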
+ // insert some more blocks store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatcinghBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatcinghBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchinghBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) + assert(store.master.getMatchinghBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } - val matchedBlockIds = store.master.getMatcinghBlockIds(_ match { + val matchedBlockIds = store.master.getMatchinghBlockIds(_ match { case RDDBlockId(1, _) => true case _ => false }, askSlaves = true) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 7af43f0cc5276..79075a7eb847c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -108,7 +108,7 @@ class JsonProtocolSuite extends FunSuite { // BlockId testBlockId(RDDBlockId(1, 2)) testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L, "")) + testBlockId(BroadcastBlockId(1L, "insert_words_of_wisdom_here")) testBlockId(TaskResultBlockId(1L)) testBlockId(StreamBlockId(1, 2L)) } From d25a86e7ae61a400764cb5f9c47f55f42f32220f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Apr 2014 16:11:19 -0700 Subject: [PATCH 36/37] Fixed stupid typo. --- .../org/apache/spark/storage/BlockManagerMaster.scala | 2 +- .../scala/org/apache/spark/ContextCleanerSuite.scala | 6 +++--- .../org/apache/spark/storage/BlockManagerSuite.scala | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4191f4e4c71e4..497a0f6eb5c1d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -193,7 +193,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * updated block statuses. This is useful when the master is not informed of the given block * by all block managers. 
*/ - def getMatchinghBlockIds( + def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index c828e4ebcf924..5b6120d965c5c 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -390,21 +390,21 @@ class CleanerTester( toBeCleanedBroadcstIds.isEmpty private def getRDDBlocks(rddId: Int): Seq[BlockId] = { - blockManager.master.getMatchinghBlockIds( _ match { + blockManager.master.getMatchingBlockIds( _ match { case RDDBlockId(`rddId`, _) => true case _ => false }, askSlaves = true) } private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { - blockManager.master.getMatchinghBlockIds( _ match { + blockManager.master.getMatchingBlockIds( _ match { case ShuffleBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { - blockManager.master.getMatchinghBlockIds( _ match { + blockManager.master.getMatchingBlockIds( _ match { case BroadcastBlockId(`broadcastId`, _) => true case _ => false }, askSlaves = true) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index fb1920bd47fb1..9aaf3601e430e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -868,8 +868,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchinghBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatchinghBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) // insert some more blocks store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -877,14 +877,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchinghBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatchinghBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } - val matchedBlockIds = store.master.getMatchinghBlockIds(_ match { + val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true case _ => false }, askSlaves = true) From 61b8d6e2b37d7b0e200e9a506c24d18c961f8d73 Mon 
Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Apr 2014 21:43:48 -0700 Subject: [PATCH 37/37] Fixed issue with Tachyon + new BlockManager methods. --- .../org/apache/spark/ContextCleaner.scala | 21 +++++-------------- .../scala/org/apache/spark/SparkContext.scala | 8 ++++--- .../spark/broadcast/BroadcastFactory.scala | 6 +++--- .../spark/broadcast/HttpBroadcast.scala | 8 +++---- .../apache/spark/storage/BlockManager.scala | 9 +++++--- .../spark/storage/DiskBlockManager.scala | 2 +- .../spark/util/TimeStampedHashMap.scala | 2 -- .../apache/spark/ContextCleanerSuite.scala | 6 +++--- 8 files changed, 26 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 250d9d55c6211..54e08d7866f75 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -65,7 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Whether the cleaning thread will block on cleanup tasks. - * This is set to true only for tests. */ + * This is set to true only for tests. + */ private val blockOnCleanupTasks = sc.conf.getBoolean( "spark.cleaner.referenceTracking.blocking", false) @@ -133,7 +134,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform RDD cleanup. */ - private def doCleanupRDD(rddId: Int, blocking: Boolean) { + def doCleanupRDD(rddId: Int, blocking: Boolean) { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) @@ -145,7 +146,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform shuffle cleanup, asynchronously. */ - private def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -158,7 +159,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - private def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { try { logDebug("Cleaning broadcast " + broadcastId) broadcastManager.unbroadcast(broadcastId, true, blocking) @@ -175,18 +176,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { // Used for testing. These methods explicitly blocks until cleanup is completed // to ensure that more reliable testing. 
- - def cleanupRDD(rdd: RDD[_]) { - doCleanupRDD(rdd.id, blocking = true) - } - - def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { - doCleanupShuffle(shuffleDependency.shuffleId, blocking = true) - } - - def cleanupBroadcast[T](broadcast: Broadcast[T]) { - doCleanupBroadcast(broadcast.id, blocking = true) - } } private object ContextCleaner { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dfc173357c12b..d7124616d3bfb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -233,11 +233,13 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(this) dagScheduler.start() - private[spark] val cleaner: Option[ContextCleaner] = + private[spark] val cleaner: Option[ContextCleaner] = { if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) - } else None - + } else { + None + } + } cleaner.foreach(_.start()) postEnvironmentUpdate() diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index a7867bcaabfc2..c7f7c59cfb449 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,8 +27,8 @@ import org.apache.spark.SparkConf * entire Spark job. */ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] - def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) - def stop() + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit + def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 51399bb980fcd..f6a8a8af91e4b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -86,11 +86,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea val start = System.nanoTime value_ = HttpBroadcast.read[T](id) /* - * Storing the broadcast data in BlockManager so that all - * so that all subsequent tasks using the broadcast variable - * does not need to fetch it again. The BlockManagerMaster - * does not need to be told about this block as no one - * needs to know about this data block. + * We cache broadcast data in the BlockManager so that subsequent tasks using it + * do not need to re-fetch. This data is only used locally and no other node + * needs to fetch this block, so we don't notify the master. */ SparkEnv.get.blockManager.putSingle( blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a9e3e48767b1b..b021564477c47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -220,13 +220,16 @@ private[spark] class BlockManager( } } - /** Get the BlockStatus for the block identified by the given ID, if it exists. 
*/ + /** + * Get the BlockStatus for the block identified by the given ID, if it exists. + * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. + */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L - val tachyonSize = if (tachyonStore.contains(blockId)) tachyonStore.getSize(blockId) else 0L - BlockStatus(info.level, memSize, diskSize, tachyonSize) + // Assume that block is not in Tachyon + BlockStatus(info.level, memSize, diskSize, 0L) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 866c52150a48e..7a24c8f57f43b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -95,7 +95,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD getBlockLocation(blockId).file.exists() } - /** List all the blocks currently stored in disk by the disk manager. */ + /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { // Get all the files inside the array of array of directories subDirs.flatten.filter(_ != null).flatMap { dir => diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 5c239329588d8..8de75ba9a9c92 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -108,7 +108,6 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } } - // Should we return previous value directly or as Option? 
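The question removed above ("previous value directly or as Option?") is settled by the surrounding code: the underlying Java map returns null when there was no previous entry, and wrapping the result in Option(...) hands callers None instead of null. The same idiom outside Spark:

import java.util.concurrent.ConcurrentHashMap

val map = new ConcurrentHashMap[String, String]()
// No previous value: putIfAbsent returns null, and Option(null) becomes None.
val first: Option[String] = Option(map.putIfAbsent("key", "v1"))   // None
// Second attempt: the existing value is returned and the map is unchanged.
val second: Option[String] = Option(map.putIfAbsent("key", "v2"))  // Some("v1")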
def putIfAbsent(key: A, value: B): Option[B] = { val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) Option(prev).map(_.value) @@ -148,5 +147,4 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa def getTimestamp(key: A): Option[Long] = { getTimeStampedValue(key).map(_.timestamp) } - } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 5b6120d965c5c..e50981cf6fb20 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -57,7 +57,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) // Explicit cleanup - cleaner.cleanupRDD(rdd) + cleaner.doCleanupRDD(rdd.id, blocking = true) tester.assertCleanup() // Verify that RDDs can be re-executed after cleaning up @@ -70,7 +70,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) // Explicit cleanup - shuffleDeps.foreach(s => cleaner.cleanupShuffle(s)) + shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) tester.assertCleanup() // Verify that shuffles can be re-executed after cleaning up @@ -82,7 +82,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) // Explicit cleanup - cleaner.cleanupBroadcast(broadcast) + cleaner.doCleanupBroadcast(broadcast.id, blocking = true) tester.assertCleanup() }
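For readers following the reference-tracking design these tests exercise: the cleaner keeps weak references to RDDs, shuffles and broadcasts, and performs cleanup once the JVM has collected the referent. A heavily simplified, self-contained sketch of that idea; the names and structure here are illustrative only and differ from ContextCleaner's actual implementation:

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.concurrent.ConcurrentHashMap

// Track objects via weak references; once the JVM collects a referent, its
// reference is enqueued and the daemon thread runs the supplied cleanup action.
class SimpleCleaner[T <: AnyRef](cleanup: Long => Unit) {
  private val queue = new ReferenceQueue[T]
  private val refs = new ConcurrentHashMap[WeakReference[T], java.lang.Long]

  /** Start tracking `obj`; cleanup(id) runs some time after obj becomes unreachable. */
  def register(obj: T, id: Long) {
    refs.put(new WeakReference(obj, queue), id)
  }

  private val cleaningThread = new Thread("simple-cleaner") {
    override def run() {
      while (true) {
        // Blocks until a collected reference is enqueued by the garbage collector.
        val ref = queue.remove().asInstanceOf[WeakReference[T]]
        val id = refs.remove(ref)
        if (id != null) cleanup(id)
      }
    }
  }
  cleaningThread.setDaemon(true)
  cleaningThread.start()
}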