diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 315ed91f81df3..3f667a4a0f9c5 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -20,25 +20,25 @@ package org.apache.spark
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BlockManager, BlockStatus, RDDBlockId, StorageLevel}
+import org.apache.spark.storage._
 
 /**
- * Spark class responsible for passing RDDs split contents to the BlockManager and making
+ * Spark class responsible for passing RDDs partition contents to the BlockManager and making
  * sure a node doesn't load two copies of an RDD at once.
  */
 private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
 
-  /** Keys of RDD splits that are being computed/loaded. */
+  /** Keys of RDD partitions that are being computed/loaded. */
   private val loading = new HashSet[RDDBlockId]()
 
-  /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
+  /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */
   def getOrCompute[T](
       rdd: RDD[T],
-      split: Partition,
+      partition: Partition,
       context: TaskContext,
       storageLevel: StorageLevel): Iterator[T] = {
-    val key = RDDBlockId(rdd.id, split.index)
+    val key = RDDBlockId(rdd.id, partition.index)
     logDebug(s"Looking for partition $key")
     blockManager.get(key) match {
       case Some(values) =>
@@ -46,79 +46,28 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
 
       case None =>
-        // Mark the split as loading (unless someone else marks it first)
-        loading.synchronized {
-          if (loading.contains(key)) {
-            logInfo(s"Another thread is loading $key, waiting for it to finish...")
-            while (loading.contains(key)) {
-              try {
-                loading.wait()
-              } catch {
-                case e: Exception =>
-                  logWarning(s"Got an exception while waiting for another thread to load $key", e)
-              }
-            }
-            logInfo(s"Finished waiting for $key")
-            /* See whether someone else has successfully loaded it. The main way this would fail
-             * is for the RDD-level cache eviction policy if someone else has loaded the same RDD
-             * partition but we didn't want to make space for it. However, that case is unlikely
-             * because it's unlikely that two threads would work on the same RDD partition. One
-             * downside of the current code is that threads wait serially if this does happen.
-             */
-            blockManager.get(key) match {
-              case Some(values) =>
-                return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
-              case None =>
-                logInfo(s"Whoever was loading $key failed; we'll try it ourselves")
-                loading.add(key)
-            }
-          } else {
-            loading.add(key)
-          }
+        // Acquire a lock for loading this partition
+        // If another thread already holds the lock, wait for it to finish and return its results
+        val storedValues = acquireLockForPartition[T](key)
+        if (storedValues.isDefined) {
+          return new InterruptibleIterator[T](context, storedValues.get)
         }
+
+        // Otherwise, we have to load the partition ourselves
         try {
-          // If we got here, we have to load the split
           logInfo(s"Partition $key not found, computing it")
-          val computedValues = rdd.computeOrReadCheckpoint(split, context)
+          val computedValues = rdd.computeOrReadCheckpoint(partition, context)
 
-          // Persist the result, so long as the task is not running locally
+          // If the task is running locally, do not persist the result
           if (context.runningLocally) {
             return computedValues
           }
 
-          // Keep track of blocks with updated statuses
-          var updatedBlocks = Seq[(BlockId, BlockStatus)]()
-          val returnValue: Iterator[T] = {
-            if (storageLevel.useDisk && !storageLevel.useMemory) {
-              /* In the case that this RDD is to be persisted using DISK_ONLY
-               * the iterator will be passed directly to the blockManager (rather then
-               * caching it to an ArrayBuffer first), then the resulting block data iterator
-               * will be passed back to the user. If the iterator generates a lot of data,
-               * this means that it doesn't all have to be held in memory at one time.
-               * This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
-               * blocks aren't dropped by the block store before enabling that. */
-              updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
-              blockManager.get(key) match {
-                case Some(values) =>
-                  values.asInstanceOf[Iterator[T]]
-                case None =>
-                  logInfo(s"Failure to store $key")
-                  throw new SparkException("Block manager failed to return persisted value")
-              }
-            } else {
-              // In this case the RDD is cached to an array buffer. This will save the results
-              // if we're dealing with a 'one-time' iterator
-              val elements = new ArrayBuffer[Any]
-              elements ++= computedValues
-              updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true)
-              elements.iterator.asInstanceOf[Iterator[T]]
-            }
-          }
-
-          // Update task metrics to include any blocks whose storage status is updated
-          val metrics = context.taskMetrics
-          metrics.updatedBlocks = Some(updatedBlocks)
-
-          new InterruptibleIterator(context, returnValue)
+          // Otherwise, cache the values and keep track of any updates in block statuses
+          val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+          val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)
+          context.taskMetrics.updatedBlocks = Some(updatedBlocks)
+          new InterruptibleIterator(context, cachedValues)
         } finally {
           loading.synchronized {
@@ -128,4 +77,76 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         }
     }
   }
+
+  /**
+   * Acquire a loading lock for the partition identified by the given block ID.
+   *
+   * If the lock is free, just acquire it and return None. Otherwise, another thread is already
+   * loading the partition, so we wait for it to finish and return the values loaded by the thread.
+   */
+  private def acquireLockForPartition[T](id: RDDBlockId): Option[Iterator[T]] = {
+    loading.synchronized {
+      if (!loading.contains(id)) {
+        // If the partition is free, acquire its lock to compute its value
+        loading.add(id)
+        None
+      } else {
+        // Otherwise, wait for another thread to finish and return its result
+        logInfo(s"Another thread is loading $id, waiting for it to finish...")
+        while (loading.contains(id)) {
+          try {
+            loading.wait()
+          } catch {
+            case e: Exception =>
+              logWarning(s"Exception while waiting for another thread to load $id", e)
+          }
+        }
+        logInfo(s"Finished waiting for $id")
+        val values = blockManager.get(id)
+        if (!values.isDefined) {
+          /* The block is not guaranteed to exist even after the other thread has finished.
+           * For instance, the block could be evicted after it was put, but before our get.
+           * In this case, we still need to load the partition ourselves. */
+          logInfo(s"Whoever was loading $id failed; we'll try it ourselves")
+          loading.add(id)
+        }
+        values.map(_.asInstanceOf[Iterator[T]])
+      }
+    }
+  }
+
+  /**
+   * Cache the values of a partition, keeping track of any updates in the storage statuses
+   * of other blocks along the way.
+   */
+  private def putInBlockManager[T](
+      key: BlockId,
+      values: Iterator[T],
+      storageLevel: StorageLevel,
+      updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = {
+
+    if (!storageLevel.useMemory) {
+      /* This RDD is not to be cached in memory, so we can just pass the computed values
+       * as an iterator directly to the BlockManager, rather than first fully unrolling
+       * it in memory. The latter option potentially uses much more memory and risks OOM
+       * exceptions that can be avoided. */
+      updatedBlocks ++= blockManager.put(key, values, storageLevel, tellMaster = true)
+      blockManager.get(key) match {
+        case Some(v) => v.asInstanceOf[Iterator[T]]
+        case None =>
+          logInfo(s"Failure to store $key")
+          throw new BlockException(key, s"Block manager failed to return cached value for $key!")
+      }
+    } else {
+      /* This RDD is to be cached in memory. In this case we cannot pass the computed values
+       * to the BlockManager as an iterator and expect to read it back later. This is because
+       * we may end up dropping a partition from memory store before getting it back, e.g.
+       * when the entirety of the RDD does not fit in memory.
+       */
+      val elements = new ArrayBuffer[Any]
+      elements ++= values
+      updatedBlocks ++= blockManager.put(key, elements, storageLevel, tellMaster = true)
+      elements.iterator.asInstanceOf[Iterator[T]]
+    }
+  }
+
 }
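Note on the CacheManager change above: the locking behavior that used to be inlined in `getOrCompute` is unchanged, it has just moved into `acquireLockForPartition`. Threads that lose the race wait serially on the shared `loading` set and fall back to computing the partition themselves if the winner's block never materializes. The following is only an illustrative, self-contained sketch of that wait/notify pattern, not code from this patch; `BlockKey`, `LoadingCoordinator`, `acquireOrAwait`, `lookup`, and `release` are invented stand-ins for `RDDBlockId`, `CacheManager`, `acquireLockForPartition`, `blockManager.get`, and the cleanup in `getOrCompute`'s `finally` block.

```scala
import scala.collection.mutable.HashSet

// Invented stand-in for RDDBlockId, for illustration only.
case class BlockKey(rddId: Int, partitionIndex: Int)

class LoadingCoordinator[V] {
  // Keys currently being computed; also serves as the monitor for wait/notifyAll.
  private val loading = new HashSet[BlockKey]

  /**
   * Returns None if the caller acquired the "lock" and must compute the value itself,
   * or Some(v) if another thread finished loading the value while we waited.
   */
  def acquireOrAwait(key: BlockKey)(lookup: BlockKey => Option[V]): Option[V] = loading.synchronized {
    if (!loading.contains(key)) {
      loading.add(key)
      None
    } else {
      while (loading.contains(key)) {
        loading.wait()  // releases the monitor while waiting
      }
      // The other thread may have failed to store the value (e.g. the block was
      // evicted right after being put), in which case we take over the computation.
      val stored = lookup(key)
      if (stored.isEmpty) {
        loading.add(key)
      }
      stored
    }
  }

  /** Called by the computing thread in a finally block, whether or not it succeeded. */
  def release(key: BlockKey): Unit = loading.synchronized {
    loading.remove(key)
    loading.notifyAll()
  }
}
```

As in the patch, a waiter that finds nothing in the store after being woken takes over the loading lock itself, so in the common case at most one thread computes a given partition at a time.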
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 0e8d551e4b2ab..bbf9f7388b074 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.scheduler
 
+import scala.language.existentials
+
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.mutable.HashMap
-import scala.language.existentials
 
 import org.apache.spark._
 import org.apache.spark.rdd.{RDD, RDDCheckpointData}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 0098b5a59d1a5..859cdc524a581 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -25,10 +25,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.rdd.{RDD, RDDCheckpointData}
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage._
 import org.apache.spark.shuffle.ShuffleWriter
 
 private[spark] object ShuffleMapTask {
@@ -150,7 +147,7 @@ private[spark] class ShuffleMapTask(
       for (elem <- rdd.iterator(split, context)) {
         writer.write(elem.asInstanceOf[Product2[Any, Any]])
       }
-      return writer.stop(success = true).get
+      writer.stop(success = true).get
     } catch {
       case e: Exception =>
         if (writer != null) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index 9a9be047c7245..b9b53b1a2f118 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -24,11 +24,11 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.Logging
 
 /**
- * Abstract class to store blocks
+ * Abstract class to store blocks.
  */
-private[spark]
-abstract class BlockStore(val blockManager: BlockManager) extends Logging {
-  def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult
+private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging {
+
+  def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult
 
   /**
    * Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -37,11 +37,17 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
    * @return a PutResult that contains the size of the data, as well as the values put if
    *         returnValues is true (if not, the result's data field can be null)
    */
-  def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
-    returnValues: Boolean) : PutResult
+  def putValues(
+      blockId: BlockId,
+      values: Iterator[Any],
+      level: StorageLevel,
+      returnValues: Boolean): PutResult
 
-  def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
-    returnValues: Boolean) : PutResult
+  def putValues(
+      blockId: BlockId,
+      values: ArrayBuffer[Any],
+      level: StorageLevel,
+      returnValues: Boolean): PutResult
 
   /**
    * Return the size of a block in bytes.
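For context on the two `putValues` overloads reformatted above: one takes a one-shot `Iterator[Any]`, the other an already-materialized `ArrayBuffer[Any]` that can be traversed more than once. A minimal sketch of that shape, under invented names (`ToyStore`, `Key`, `SimplePutResult`), follows; the simplified result type only mirrors the size-plus-optional-values contract described in the scaladoc, not Spark's actual `PutResult`.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

// Invented stand-ins for BlockId and PutResult, for illustration only.
case class Key(name: String)
case class SimplePutResult(size: Long, returnedValues: Option[Iterator[Any]])

class ToyStore {
  private val entries = new HashMap[Key, ArrayBuffer[Any]]

  // Iterator overload: the input can only be consumed once, so unroll it first.
  def putValues(key: Key, values: Iterator[Any], returnValues: Boolean): SimplePutResult = {
    val buffer = new ArrayBuffer[Any]
    buffer ++= values
    putValues(key, buffer, returnValues)
  }

  // ArrayBuffer overload: the input is already materialized and can be re-traversed,
  // so it can be stored directly and handed back as a fresh iterator if requested.
  def putValues(key: Key, values: ArrayBuffer[Any], returnValues: Boolean): SimplePutResult = {
    entries(key) = values
    SimplePutResult(values.size.toLong, if (returnValues) Some(values.iterator) else None)
  }
}
```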
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 084a566c48560..71f66c826c5b3 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -58,11 +58,11 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       val elements = new ArrayBuffer[Any]
       elements ++= values
       val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
-      tryToPut(blockId, elements, sizeEstimate, true)
-      PutResult(sizeEstimate, Left(values.toIterator))
+      val putAttempt = tryToPut(blockId, elements, sizeEstimate, deserialized = true)
+      PutResult(sizeEstimate, Left(values.toIterator), putAttempt.droppedBlocks)
     } else {
-      tryToPut(blockId, bytes, bytes.limit, false)
-      PutResult(bytes.limit(), Right(bytes.duplicate()))
+      val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
+      PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
     }
   }
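The MemoryStore hunk above threads eviction information outward: `tryToPut` returns an attempt object whose `droppedBlocks` field lists blocks dropped to make room, and that list now rides along in the returned `PutResult`. Presumably this is what lets callers such as `CacheManager.putInBlockManager` fold those drops into the task's `updatedBlocks`. The sketch below illustrates only that bookkeeping pattern, with invented types (`Id`, `Status`, `PutAttempt`) and an invented helper (`recordPut`); it is not Spark's implementation.

```scala
import scala.collection.mutable.ArrayBuffer

// Invented stand-ins for BlockId, BlockStatus, and the result of a put attempt.
case class Id(name: String)
case class Status(storedInMemory: Boolean)
case class PutAttempt(success: Boolean, droppedBlocks: Seq[(Id, Status)])

object PutBookkeeping {
  /** Fold the outcome of one put attempt into the running list of block-status updates. */
  def recordPut(
      updatedBlocks: ArrayBuffer[(Id, Status)],
      attempt: PutAttempt,
      putBlock: Id): Unit = {
    // Blocks evicted to make room for this put are status changes worth reporting too.
    updatedBlocks ++= attempt.droppedBlocks
    if (attempt.success) {
      updatedBlocks += ((putBlock, Status(storedInMemory = true)))
    }
  }
}
```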