Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{CompletionIterator, Utils}

private[spark]
class CartesianPartition(
Expand All @@ -49,7 +50,8 @@ private[spark]
class CartesianRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1 : RDD[T],
var rdd2 : RDD[U])
var rdd2 : RDD[U],
val cacheFetchedInLocal: Boolean = false)
extends RDD[(T, U)](sc, Nil)
with Serializable {

Expand All @@ -71,9 +73,103 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
}

override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
val blockManager = SparkEnv.get.blockManager
val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
val blockId2 = RDDBlockId(rdd2.id, currSplit.s2.index)
var cachedInLocal = false
var holdReadLock = false

// Try to get data from the local, otherwise it will be cached to the local if user set
// cacheFetchedInLocal as true.
def getOrElseCache(
rdd: RDD[U],
partition: Partition,
context: TaskContext,
level: StorageLevel): Iterator[U] = {
getLocalValues() match {
case Some(result) =>
return result
case None => if (holdReadLock) {
blockManager.releaseLock(blockId2)
throw new SparkException(s"get() failed for block $blockId2 even though we held a lock")
}
}

val iterator = rdd.iterator(partition, context)
val status = blockManager.getStatus(blockId2)
if (!cacheFetchedInLocal || (status.isDefined && status.get.storageLevel.isValid)) {
// If user don't want cache the block fetched from remotely, just return it.
// Or if the block is cached in local, wo shouldn't cache it again.
return iterator
}

// Keep read lock, because next we need read it. And don't tell master.
val putSuccess = blockManager.putIterator[U](blockId2, iterator, level, false, true)
if (putSuccess) {
cachedInLocal = true
// After we cached the block, we also hold the block read lock until this task finished.
holdReadLock = true
logInfo(s"Cache the block $blockId2 to local successful.")
val readLocalBlock = blockManager.getLocalValues(blockId2).getOrElse {
blockManager.releaseLock(blockId2)
throw new SparkException(s"get() failed for block $blockId2 even though we held a lock")
}

new InterruptibleIterator[U](context, readLocalBlock.data.asInstanceOf[Iterator[U]])
} else {
blockManager.releaseLock(blockId2)
// There shouldn't a error caused by put in memory, because we use MEMORY_AND_DISK to
// cache it.
throw new SparkException(s"Cache block $blockId2 in local failed even though it's $level")
}
}

// Get block from local, and update the metrics.
def getLocalValues(): Option[Iterator[U]] = {
blockManager.getLocalValues(blockId2) match {
case Some(result) =>
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(result.bytes)
val localIter =
new InterruptibleIterator[U](context, result.data.asInstanceOf[Iterator[U]]) {
override def next(): U = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
Some(localIter)
case None =>
None
}
}

val resultIter =
for (x <- rdd1.iterator(currSplit.s1, context);
y <- getOrElseCache(rdd2, currSplit.s2, context, StorageLevel.MEMORY_AND_DISK))
yield (x, y)

CompletionIterator[(T, U), Iterator[(T, U)]](resultIter,
removeCachedBlock(blockId2, holdReadLock, cachedInLocal))
}

/**
* Remove the cached block. If we hold the read lock, we also need release it.
*/
def removeCachedBlock(
blockId: RDDBlockId,
holdReadLock: Boolean,
cachedInLocal: Boolean): Unit = {
val blockManager = SparkEnv.get.blockManager
if (holdReadLock) {
// If hold the read lock, we need release it.
blockManager.releaseLock(blockId)
}
// Whether the block it persisted by the user.
val persistedInLocal =
blockManager.master.getLocations(blockId).contains(blockManager.blockManagerId)
if (!persistedInLocal && (cachedInLocal || blockManager.isRemovable(blockId))) {
blockManager.removeOrMarkAsRemovable(blockId, false)
}
}

override def getDependencies: Seq[Dependency[_]] = List(
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,12 @@ abstract class RDD[T: ClassTag](
/**
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*
* @param cacheFetchedInLocal Whether cache the remotely fetched block in local.
*/
def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope {
new CartesianRDD(sc, this, other)
def cartesian[U: ClassTag](other: RDD[U],
cacheFetchedInLocal: Boolean = false): RDD[(T, U)] = withScope {
new CartesianRDD(sc, this, other, cacheFetchedInLocal)
}

/**
Expand Down
43 changes: 41 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io._
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable
import scala.collection.mutable.HashMap
Expand Down Expand Up @@ -199,6 +200,9 @@ private[spark] class BlockManager(

private var blockReplicationPolicy: BlockReplicationPolicy = _

// Record the removable block.
private lazy val removableBlocks = ConcurrentHashMap.newKeySet[BlockId]()

/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
Expand Down Expand Up @@ -784,9 +788,11 @@ private[spark] class BlockManager(
blockId: BlockId,
values: Iterator[T],
level: StorageLevel,
tellMaster: Boolean = true): Boolean = {
tellMaster: Boolean = true,
keepReadLock: Boolean = false): Boolean = {
require(values != null, "Values is null")
doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster) match {
doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster,
keepReadLock)match {
case None =>
true
case Some(iter) =>
Expand Down Expand Up @@ -1461,6 +1467,38 @@ private[spark] class BlockManager(
}
}

/**
* Whether the block is removable.
*/
def isRemovable(blockId: BlockId): Boolean = {
removableBlocks.contains(blockId)
}

/**
* Try to remove the block without blocking. Mark it as removable if it is in use.
*/
def removeOrMarkAsRemovable(blockId: BlockId, tellMaster: Boolean = true): Unit = {
// Try to lock for writing without blocking
blockInfoManager.lockForWriting(blockId, false) match {
case None =>
// Because lock in unblocking manner, so the block may not exist or be used by other task.
blockInfoManager.get(blockId) match {
case None =>
logWarning(s"Asked to remove block $blockId, which does not exist")
removableBlocks.remove(blockId)
case Some(_) =>
// The block is in use, mark it as removable
logDebug(s"Marking block $blockId as removable")
removableBlocks.add(blockId)
}
case Some(info) =>
logDebug(s"Removing block $blockId")
removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster)
addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty)
removableBlocks.remove(blockId)
}
}

private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = {
Option(TaskContext.get()).foreach { c =>
c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
Expand All @@ -1478,6 +1516,7 @@ private[spark] class BlockManager(
// Closing should be idempotent, but maybe not for the NioBlockTransferService.
shuffleClient.close()
}
removableBlocks.clear()
diskBlockManager.stop()
rpcEnv.stop(slaveEndpoint)
blockInfoManager.clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
package org.apache.spark.storage

import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.concurrent.Future
import scala.language.implicitConversions
import scala.language.postfixOps
import scala.reflect.ClassTag
Expand Down Expand Up @@ -67,6 +68,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true)
val shuffleManager = new SortShuffleManager(new SparkConf(false))

private implicit val ec = ExecutionContext.global

// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m"))

Expand Down Expand Up @@ -101,6 +104,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
blockManager
}

private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
TaskContext.setTaskContext(
new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
}
}

override def beforeEach(): Unit = {
super.beforeEach()
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
Expand Down Expand Up @@ -1280,6 +1293,40 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(master.getLocations("item").isEmpty)
}

test("remote block without blocking") {
store = makeBlockManager(8000, "executor1")
val arr = new Array[Byte](4000)
store.registerTask(0)
store.registerTask(1)
withTaskId(0) {
store.putSingle("block", arr, StorageLevel.MEMORY_AND_DISK, true)
// lock the block with read lock
store.get("block")
}
val future = Future {
withTaskId(1) {
// block is in use, mark it as removable
store.removeOrMarkAsRemovable("block")
store.isRemovable("block")
}
}
Thread.sleep(300)
assert(store.getStatus("block").isDefined, "block should not be removed")
assert(ThreadUtils.awaitResult(future, 1.seconds), "block should be marked as removable")
withTaskId(0) {
store.releaseLock("block")
}
val future1 = Future {
withTaskId(1) {
// remove it
store.removeOrMarkAsRemovable("block")
!store.isRemovable("block")
}
}
assert(ThreadUtils.awaitResult(future1, 1.seconds), "block should not be marked as removable")
assert(master.getLocations("block").isEmpty, "block should be removed")
}

class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
var numCalls = 0

Expand Down Expand Up @@ -1307,7 +1354,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
blockData: ManagedBuffer,
level: StorageLevel,
classTag: ClassTag[_]): Future[Unit] = {
import scala.concurrent.ExecutionContext.Implicits.global
Future {}
}

Expand Down