diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 57108dcedcf0c..88e4cefdace31 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -22,7 +22,8 @@ import java.io.{IOException, ObjectOutputStream}
 import scala.reflect.ClassTag
 
 import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.util.{CompletionIterator, Utils}
 
 private[spark]
 class CartesianPartition(
@@ -49,7 +50,8 @@ private[spark]
 class CartesianRDD[T: ClassTag, U: ClassTag](
     sc: SparkContext,
     var rdd1 : RDD[T],
-    var rdd2 : RDD[U])
+    var rdd2 : RDD[U],
+    val cacheFetchedInLocal: Boolean = false)
   extends RDD[(T, U)](sc, Nil)
   with Serializable {
 
@@ -71,9 +73,103 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
+    val blockManager = SparkEnv.get.blockManager
     val currSplit = split.asInstanceOf[CartesianPartition]
-    for (x <- rdd1.iterator(currSplit.s1, context);
-         y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+    val blockId2 = RDDBlockId(rdd2.id, currSplit.s2.index)
+    var cachedInLocal = false
+    var holdReadLock = false
+
+    // Try to read the block from the local block manager first. Otherwise compute or fetch it,
+    // and cache it locally if the user set cacheFetchedInLocal to true.
+    def getOrElseCache(
+        rdd: RDD[U],
+        partition: Partition,
+        context: TaskContext,
+        level: StorageLevel): Iterator[U] = {
+      getLocalValues() match {
+        case Some(result) =>
+          return result
+        case None => if (holdReadLock) {
+          blockManager.releaseLock(blockId2)
+          throw new SparkException(s"get() failed for block $blockId2 even though we held a lock")
+        }
+      }
+
+      val iterator = rdd.iterator(partition, context)
+      val status = blockManager.getStatus(blockId2)
+      if (!cacheFetchedInLocal || (status.isDefined && status.get.storageLevel.isValid)) {
+        // If the user doesn't want to cache the remotely fetched block, just return it.
+        // Likewise, if the block is already cached locally, we shouldn't cache it again.
+        return iterator
+      }
+
+      // Keep the read lock, because we need to read the block next, and don't tell the master.
+      val putSuccess = blockManager.putIterator[U](blockId2, iterator, level, false, true)
+      if (putSuccess) {
+        cachedInLocal = true
+        // After caching the block, we keep holding its read lock until this task finishes.
+        holdReadLock = true
+        logInfo(s"Successfully cached block $blockId2 locally.")
+        val readLocalBlock = blockManager.getLocalValues(blockId2).getOrElse {
+          blockManager.releaseLock(blockId2)
+          throw new SparkException(s"get() failed for block $blockId2 even though we held a lock")
+        }
+
+        new InterruptibleIterator[U](context, readLocalBlock.data.asInstanceOf[Iterator[U]])
+      } else {
+        blockManager.releaseLock(blockId2)
+        // The put shouldn't fail for lack of memory, because we cache the block with
+        // MEMORY_AND_DISK.
+        throw new SparkException(s"Failed to cache block $blockId2 locally even though it's $level")
+      }
+    }
+
+    // Get the block from the local block manager and update the input metrics.
+    def getLocalValues(): Option[Iterator[U]] = {
+      blockManager.getLocalValues(blockId2) match {
+        case Some(result) =>
+          val existingMetrics = context.taskMetrics().inputMetrics
+          existingMetrics.incBytesRead(result.bytes)
+          val localIter =
+            new InterruptibleIterator[U](context, result.data.asInstanceOf[Iterator[U]]) {
+              override def next(): U = {
+                existingMetrics.incRecordsRead(1)
+                delegate.next()
+              }
+            }
+          Some(localIter)
+        case None =>
+          None
+      }
+    }
+
+    val resultIter =
+      for (x <- rdd1.iterator(currSplit.s1, context);
+        y <- getOrElseCache(rdd2, currSplit.s2, context, StorageLevel.MEMORY_AND_DISK))
+        yield (x, y)
+
+    CompletionIterator[(T, U), Iterator[(T, U)]](resultIter,
+      removeCachedBlock(blockId2, holdReadLock, cachedInLocal))
+  }
+
+  /**
+   * Remove the cached block. If we hold the read lock, we also need to release it.
+   */
+  def removeCachedBlock(
+      blockId: RDDBlockId,
+      holdReadLock: Boolean,
+      cachedInLocal: Boolean): Unit = {
+    val blockManager = SparkEnv.get.blockManager
+    if (holdReadLock) {
+      // If we hold the read lock, we need to release it.
+      blockManager.releaseLock(blockId)
+    }
+    // Whether the block is persisted locally by the user.
+    val persistedInLocal =
+      blockManager.master.getLocations(blockId).contains(blockManager.blockManagerId)
+    if (!persistedInLocal && (cachedInLocal || blockManager.isRemovable(blockId))) {
+      blockManager.removeOrMarkAsRemovable(blockId, false)
+    }
   }
 
   override def getDependencies: Seq[Dependency[_]] = List(
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 63a87e7f09d85..3c64bf5e8b0dd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -670,9 +670,12 @@ abstract class RDD[T: ClassTag](
   /**
    * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
    * elements (a, b) where a is in `this` and b is in `other`.
+   *
+   * @param cacheFetchedInLocal Whether to cache the remotely fetched blocks locally.
    */
-  def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope {
-    new CartesianRDD(sc, this, other)
+  def cartesian[U: ClassTag](other: RDD[U],
+      cacheFetchedInLocal: Boolean = false): RDD[(T, U)] = withScope {
+    new CartesianRDD(sc, this, other, cacheFetchedInLocal)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 137d24b525155..da4c8a8705373 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
 import java.io._
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.HashMap
@@ -199,6 +200,9 @@ private[spark] class BlockManager(
 
   private var blockReplicationPolicy: BlockReplicationPolicy = _
 
+  // Record the blocks that have been marked as removable.
+  private lazy val removableBlocks = ConcurrentHashMap.newKeySet[BlockId]()
+
   /**
    * Initializes the BlockManager with the given appId. This is not performed in the constructor as
    * the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -784,9 +788,11 @@ private[spark] class BlockManager(
       blockId: BlockId,
       values: Iterator[T],
       level: StorageLevel,
-      tellMaster: Boolean = true): Boolean = {
+      tellMaster: Boolean = true,
+      keepReadLock: Boolean = false): Boolean = {
     require(values != null, "Values is null")
-    doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster) match {
+    doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster,
+      keepReadLock) match {
       case None =>
         true
       case Some(iter) =>
@@ -1461,6 +1467,38 @@ private[spark] class BlockManager(
     }
   }
 
+  /**
+   * Whether the block is removable.
+   */
+  def isRemovable(blockId: BlockId): Boolean = {
+    removableBlocks.contains(blockId)
+  }
+
+  /**
+   * Try to remove the block without blocking. Mark it as removable if it is currently in use.
+   */
+  def removeOrMarkAsRemovable(blockId: BlockId, tellMaster: Boolean = true): Unit = {
+    // Try to acquire the write lock without blocking.
+    blockInfoManager.lockForWriting(blockId, false) match {
+      case None =>
+        // The lock is non-blocking, so the block may not exist or may be in use by another task.
+        blockInfoManager.get(blockId) match {
+          case None =>
+            logWarning(s"Asked to remove block $blockId, which does not exist")
+            removableBlocks.remove(blockId)
+          case Some(_) =>
+            // The block is in use, so mark it as removable.
+            logDebug(s"Marking block $blockId as removable")
+            removableBlocks.add(blockId)
+        }
+      case Some(info) =>
+        logDebug(s"Removing block $blockId")
+        removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster)
+        addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty)
+        removableBlocks.remove(blockId)
+    }
+  }
+
   private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = {
     Option(TaskContext.get()).foreach { c =>
       c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
@@ -1478,6 +1516,7 @@ private[spark] class BlockManager(
       // Closing should be idempotent, but maybe not for the NioBlockTransferService.
       shuffleClient.close()
     }
+    removableBlocks.clear()
     diskBlockManager.stop()
     rpcEnv.stop(slaveEndpoint)
     blockInfoManager.clear()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 1e7bcdb6740f6..10f43bf1845fe 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.storage
 
 import java.nio.ByteBuffer
+import java.util.Properties
 
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
-import scala.concurrent.Future
 import scala.language.implicitConversions
 import scala.language.postfixOps
 import scala.reflect.ClassTag
@@ -67,6 +68,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true)
   val shuffleManager = new SortShuffleManager(new SparkConf(false))
 
+  private implicit val ec = ExecutionContext.global
+
   // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
   val serializer = new KryoSerializer(
     new SparkConf(false).set("spark.kryoserializer.buffer", "1m"))
@@ -101,6 +104,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     blockManager
   }
 
+  private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
+    try {
+      TaskContext.setTaskContext(
+        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+      block
+    } finally {
+      TaskContext.unset()
+    }
+  }
+
   override def beforeEach(): Unit = {
     super.beforeEach()
     // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
@@ -1280,6 +1293,40 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(master.getLocations("item").isEmpty)
   }
 
+  test("remove block without blocking") {
+    store = makeBlockManager(8000, "executor1")
+    val arr = new Array[Byte](4000)
+    store.registerTask(0)
+    store.registerTask(1)
+    withTaskId(0) {
+      store.putSingle("block", arr, StorageLevel.MEMORY_AND_DISK, true)
+      // Lock the block with a read lock.
+      store.get("block")
+    }
+    val future = Future {
+      withTaskId(1) {
+        // The block is in use, so it should only be marked as removable.
+        store.removeOrMarkAsRemovable("block")
+        store.isRemovable("block")
+      }
+    }
+    Thread.sleep(300)
+    assert(store.getStatus("block").isDefined, "block should not be removed")
+    assert(ThreadUtils.awaitResult(future, 1.seconds), "block should be marked as removable")
+    withTaskId(0) {
+      store.releaseLock("block")
+    }
+    val future1 = Future {
+      withTaskId(1) {
+        // The read lock is released now, so the block should be removed.
+        store.removeOrMarkAsRemovable("block")
+        !store.isRemovable("block")
+      }
+    }
+    assert(ThreadUtils.awaitResult(future1, 1.seconds), "block should not be marked as removable")
+    assert(master.getLocations("block").isEmpty, "block should be removed")
+  }
+
   class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
     var numCalls = 0
 
@@ -1307,7 +1354,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         blockData: ManagedBuffer,
         level: StorageLevel,
         classTag: ClassTag[_]): Future[Unit] = {
-      import scala.concurrent.ExecutionContext.Implicits.global
       Future {}
     }
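
For reference, below is a minimal usage sketch of the cacheFetchedInLocal flag introduced by this patch, not part of the patch itself. The object name, app name, master URL, and sample data are illustrative assumptions; the only API assumed beyond stock Spark is the cartesian(other, cacheFetchedInLocal) overload added above.

// Illustrative sketch only; assumes this patch is applied.
import org.apache.spark.{SparkConf, SparkContext}

object CartesianCacheExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("cartesian-cache-example").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val left = sc.parallelize(1 to 1000, 4)
    val right = sc.parallelize(1 to 1000, 4)
    // With cacheFetchedInLocal = true, the partition of `right` read by a task is cached in the
    // local block manager (MEMORY_AND_DISK) the first time it is materialized, so repeated passes
    // within the same task read the local copy instead of recomputing or re-fetching it. The
    // temporary block is removed when the task's iterator completes.
    val product = left.cartesian(right, cacheFetchedInLocal = true)
    println(product.count())
    sc.stop()
  }
}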