Commit a406079

[HOTFIX] Some clean-up in shuffle code.
Before diving into the review of apache#4450, I took a look through the existing shuffle code. Unfortunately, there are some very confusing things in this code. This patch makes a few small changes to simplify things. It is not easy to describe the changes concisely because of how convoluted the issues were:

1. There was a trait named ShuffleBlockManager that deals with only one logical function: retrieving shuffle block data given shuffle block coordinates. This trait has two implementors, FileShuffleBlockManager and IndexShuffleBlockManager. Confusingly, the vast majority of those implementations have nothing to do with this particular functionality, so I've renamed the trait to ShuffleBlockResolver and documented it.

2. The aforementioned trait had two almost identical methods, for no good reason. I removed one method (getBytes) and modified callers to use the other one. I believe the behavior is preserved in all cases.

3. The sort shuffle code uses the identifier "0" in the reduce slot of a BlockId as a placeholder. I made it into a constant, since it needs to be consistent across multiple places. For (3) there is actually a better solution that would avoid the need for this kind of workaround in the first place, but it's more complex, so I'm punting on it for now.
1 parent df35500 commit a406079
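
For orientation, here is a condensed sketch of the resolver interface and the new placeholder constant as they stand after this patch. It is pieced together from the diffs below rather than copied from the full source files, and it omits unrelated members:

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId

private[spark]
trait ShuffleBlockResolver {
  type ShuffleId = Int

  /** Retrieve the data for the given shuffle block; throws if the block is not available. */
  def getBlockData(blockId: ShuffleBlockId): ManagedBuffer

  def stop(): Unit
}

private[spark] object IndexShuffleBlockManager {
  // Placeholder reduce ID for block IDs that name a whole consolidated map output file.
  val NOOP_REDUCE_ID = 0
}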

File tree

12 files changed (+47, -52 lines)


core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala

Lines changed: 1 addition & 6 deletions
@@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup {
 // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
 private[spark]
 class FileShuffleBlockManager(conf: SparkConf)
-  extends ShuffleBlockManager with Logging {
+  extends ShuffleBlockResolver with Logging {
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
@@ -175,11 +175,6 @@ class FileShuffleBlockManager(conf: SparkConf)
     }
   }
 
-  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
-    val segment = getBlockData(blockId)
-    Some(segment.nioByteBuffer())
-  }
-
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
     if (consolidateShuffleFiles) {
       // Search all file groups associated with this shuffle.

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala

Lines changed: 12 additions & 14 deletions
@@ -27,6 +27,8 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.storage._
 
+import IndexShuffleBlockManager.NOOP_REDUCE_ID
+
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
  * Data of shuffle blocks from the same map task are stored in a single consolidated data file.
@@ -39,25 +41,18 @@ import org.apache.spark.storage._
 // Note: Changes to the format in this file should be kept in sync with
 // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
 private[spark]
-class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
+class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver {
 
   private lazy val blockManager = SparkEnv.get.blockManager
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
-  /**
-   * Mapping to a single shuffleBlockId with reduce ID 0.
-   * */
-  def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = {
-    ShuffleBlockId(shuffleId, mapId, 0)
-  }
-
   def getDataFile(shuffleId: Int, mapId: Int): File = {
-    blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0))
+    blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
   }
 
   private def getIndexFile(shuffleId: Int, mapId: Int): File = {
-    blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0))
+    blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
   }
 
   /**
@@ -97,10 +92,6 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
     }
   }
 
-  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
-    Some(getBlockData(blockId).nioByteBuffer())
-  }
-
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
     // The block is actually going to be a range of a single map output file for this map, so
     // find out the consolidated file, then the offset within that from our index
@@ -123,3 +114,10 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
 
   override def stop(): Unit = {}
 }
+
+private[spark] object IndexShuffleBlockManager {
+  // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+  // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
+  // shuffle outputs from a map for several reduces are glommed into a single file.
+  val NOOP_REDUCE_ID = 0
+}
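
To make the intended use of the constant concrete, here is a small illustrative sketch of call sites that previously hard-coded 0 in the reduce slot. placeholderIds is a hypothetical helper for this example only; in the real code the shuffleId and mapId values come from the surrounding task, as in the SortShuffleWriter change further down:

import org.apache.spark.shuffle.IndexShuffleBlockManager.NOOP_REDUCE_ID
import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId}

// The sort shuffle writes one consolidated file per map task, so the reduce slot of these
// IDs carries no information; NOOP_REDUCE_ID makes the placeholder explicit and consistent.
def placeholderIds(shuffleId: Int, mapId: Int): (ShuffleBlockId, ShuffleDataBlockId) =
  (ShuffleBlockId(shuffleId, mapId, NOOP_REDUCE_ID),
   ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))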

core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala renamed to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala

Lines changed: 9 additions & 5 deletions
@@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.storage.ShuffleBlockId
 
 private[spark]
-trait ShuffleBlockManager {
+/**
+ * Implementers of this trait understand how to retrieve block data for a logical shuffle block
+ * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to
+ * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle
+ * implementations when shuffle data is retrieved.
+ */
+trait ShuffleBlockResolver {
   type ShuffleId = Int
 
   /**
-   * Get shuffle block data managed by the local ShuffleBlockManager.
-   * @return Some(ByteBuffer) if block found, otherwise None.
+   * Retrieve the data for the specified block. If the data for that block is not available,
+   * throws an unspecified exception.
    */
-  def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer]
-
   def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
 
   def stop(): Unit
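
With getBytes removed, a caller that wants raw bytes now goes through getBlockData and converts the ManagedBuffer itself, which is what the BlockManager change below does. A minimal sketch of that pattern; shuffleBytes is an illustrative helper, not part of the patch:

import java.nio.ByteBuffer
import org.apache.spark.shuffle.ShuffleBlockResolver
import org.apache.spark.storage.ShuffleBlockId

// Resolve the block and expose it as an NIO buffer. There is no Option to unwrap any more:
// if the block is missing, getBlockData throws instead of returning None.
def shuffleBytes(resolver: ShuffleBlockResolver, blockId: ShuffleBlockId): ByteBuffer =
  resolver.getBlockData(blockId).nioByteBuffer()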

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala

Lines changed: 4 additions & 1 deletion
@@ -55,7 +55,10 @@ private[spark] trait ShuffleManager {
    */
   def unregisterShuffle(shuffleId: Int): Boolean
 
-  def shuffleBlockManager: ShuffleBlockManager
+  /**
+   * Return a resolver capable of retrieving shuffle block data based on block coordinates.
+   */
+  def shuffleBlockResolver: ShuffleBlockResolver
 
   /** Shut down this ShuffleManager. */
   def stop(): Unit

core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus
  * Obtained inside a map task to write out records to the shuffle system.
  */
 private[spark] trait ShuffleWriter[K, V] {
-  /** Write a bunch of records to this task's output */
+  /** Write a sequence of records to this task's output */
   def write(records: Iterator[_ <: Product2[K, V]]): Unit
 
   /** Close this writer, passing along whether the map completed */

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala

Lines changed: 4 additions & 4 deletions
@@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager
   override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
       : ShuffleWriter[K, V] = {
     new HashShuffleWriter(
-      shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+      shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    shuffleBlockManager.removeShuffle(shuffleId)
+    shuffleBlockResolver.removeShuffle(shuffleId)
   }
 
-  override def shuffleBlockManager: FileShuffleBlockManager = {
+  override def shuffleBlockResolver: FileShuffleBlockManager = {
     fileShuffleBlockManager
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    shuffleBlockManager.stop()
+    shuffleBlockResolver.stop()
   }
 }

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 5 additions & 5 deletions
@@ -58,26 +58,26 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
     shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
     new SortShuffleWriter(
-      shuffleBlockManager, baseShuffleHandle, mapId, context)
+      shuffleBlockResolver, baseShuffleHandle, mapId, context)
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
     if (shuffleMapNumber.containsKey(shuffleId)) {
       val numMaps = shuffleMapNumber.remove(shuffleId)
       (0 until numMaps).map{ mapId =>
-        shuffleBlockManager.removeDataByMap(shuffleId, mapId)
+        shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
       }
     }
     true
   }
 
-  override def shuffleBlockManager: IndexShuffleBlockManager = {
+  override def shuffleBlockResolver: IndexShuffleBlockManager = {
     indexShuffleBlockManager
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    shuffleBlockManager.stop()
+    shuffleBlockResolver.stop()
   }
-}
+}

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 2 additions & 3 deletions
@@ -58,16 +58,15 @@ private[spark] class SortShuffleWriter[K, V, C](
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
-      sorter = new ExternalSorter[K, V, V](
-        None, Some(dep.partitioner), None, dep.serializer)
+      sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)
       sorter.insertAll(records)
     }
 
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
-    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
+    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
     shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 5 additions & 9 deletions
@@ -301,7 +301,7 @@ private[spark] class BlockManager(
    */
   override def getBlockData(blockId: BlockId): ManagedBuffer = {
     if (blockId.isShuffle) {
-      shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
         .asInstanceOf[Option[ByteBuffer]]
@@ -439,14 +439,10 @@ private[spark] class BlockManager(
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
     if (blockId.isShuffle) {
-      val shuffleBlockManager = shuffleManager.shuffleBlockManager
-      shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match {
-        case Some(bytes) =>
-          Some(bytes)
-        case None =>
-          throw new BlockException(
-            blockId, s"Block $blockId not found on disk, though it should be")
-      }
+      val shuffleBlockManager = shuffleManager.shuffleBlockResolver
+      // TODO: This should gracefully handle case where local block is not available. Currently
+      // downstream code will throw an exception.
+      Some(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
     } else {
       doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
     }

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 2 deletions
@@ -673,7 +673,7 @@ private[spark] class ExternalSorter[K, V, C](
    * For now, we just merge all the spilled files in once pass, but this can be modified to
    * support hierarchical merging.
    */
-  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+  private def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
     if (spills.isEmpty && partitionWriters == null) {
@@ -781,7 +781,7 @@ private[spark] class ExternalSorter[K, V, C](
   /**
    * Read a partition file back as an iterator (used in our iterator method)
    */
-  def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
+  private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
     if (writer.isOpen) {
       writer.commitAndClose()
     }
