
Commit 4126c1b

Add support for migrating shuffle files
1 parent 249b214 commit 4126c1b

19 files changed (+462, -46 lines)


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 22 additions & 1 deletion

@@ -49,7 +49,7 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) {
+private class ShuffleStatus(numPartitions: Int) extends Logging {
 
   private val (readLock, writeLock) = {
     val lock = new ReentrantReadWriteLock()
@@ -121,6 +121,20 @@ private class ShuffleStatus(numPartitions: Int) {
     mapStatuses(mapIndex) = status
   }
 
+  /**
+   * Update the map output location (e.g. during migration).
+   */
+  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
+    val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
+    mapStatusOpt match {
+      case Some(mapStatus) =>
+        mapStatus.updateLocation(bmAddress)
+        invalidateSerializedMapOutputStatusCache()
+      case None =>
+        logError(s"Asked to update map output ${mapId} for untracked map status.")
+    }
+  }
+
   /**
    * Remove the map output which was served by the specified block manager.
    * This is a no-op if there is no registered map output or if the registered output is from a
@@ -479,6 +493,13 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) => shuffleStatus.updateMapOutput(mapId, bmAddress)
+      case None => logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
+    }
+  }
+
   def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }
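
A rough driver-side sketch (not part of this commit) of what the new updateMapOutput hook is for: when a map task's shuffle files have been migrated to another block manager, the master repoints the tracked location. The executor id, host, and shuffle/map ids below are made up, and this only compiles from Spark-internal code since MapOutputTrackerMaster is private[spark].

    import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
    import org.apache.spark.storage.BlockManagerId

    // Illustrative only: repoint an existing map output at the block manager that
    // now holds the migrated shuffle files. Invalidating the serialized status
    // cache (done inside updateMapOutput) lets reducers see the new location.
    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    val newLocation = BlockManagerId("exec-7", "10.0.0.7", 7337)
    tracker.updateMapOutput(shuffleId = 0, mapId = 42L, bmAddress = newLocation)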

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 13 additions & 2 deletions

@@ -57,7 +57,7 @@ import org.apache.spark.resource._
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend}
 import org.apache.spark.scheduler.local.LocalSchedulerBackend
 import org.apache.spark.shuffle.ShuffleDataIOUtils
 import org.apache.spark.shuffle.api.ShuffleDriverComponents
@@ -1586,7 +1586,7 @@ class SparkContext(config: SparkConf) extends Logging {
     listenerBus.removeListener(listener)
   }
 
-  private[spark] def getExecutorIds(): Seq[String] = {
+  def getExecutorIds(): Seq[String] = {
     schedulerBackend match {
       case b: ExecutorAllocationClient =>
         b.getExecutorIds()
@@ -1725,6 +1725,17 @@ class SparkContext(config: SparkConf) extends Logging {
     }
   }
 
+
+  @DeveloperApi
+  def decommissionExecutors(executorIds: Seq[String]): Unit = {
+    schedulerBackend match {
+      case b: CoarseGrainedSchedulerBackend =>
+        executorIds.foreach(b.decommissionExecutor)
+      case _ =>
+        logWarning("Decommissioning executors is not supported by current scheduler.")
+    }
+  }
+
   /** The version of Spark on which this application is running. */
   def version: String = SPARK_VERSION
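
A usage sketch for the newly exposed APIs, assuming the application runs with a CoarseGrainedSchedulerBackend; picking the first executor is arbitrary and purely for illustration.

    import org.apache.spark.SparkContext

    // Illustrative only: ask one currently registered executor to decommission.
    // With any other scheduler backend the new method simply logs a warning.
    val sc: SparkContext = ???  // an existing SparkContext
    sc.getExecutorIds().headOption.foreach { execId =>
      sc.decommissionExecutors(Seq(execId))
    }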

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 2 additions & 1 deletion

@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
             externalShuffleClient
           } else {
             None
-          }, blockManagerInfo)),
+          }, blockManagerInfo,
+          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 15 additions & 0 deletions

@@ -420,6 +420,21 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val STORAGE_SHUFFLE_DECOMMISSION_ENABLED =
+    ConfigBuilder("spark.storage.decommission.shuffle_blocks")
+      .doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
+        "an indexed shuffle resolver (like sort-based shuffle).")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
+  private[spark] val STORAGE_RDD_DECOMMISSION_ENABLED =
+    ConfigBuilder("spark.storage.decommission.rdd_blocks")
+      .doc("Whether to transfer RDD blocks during block manager decommissioning.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
     ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
       .internal()
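
For reference, a minimal sketch of toggling the new flags on a SparkConf; the key names come straight from the definitions above, while the chosen values are just an example.

    import org.apache.spark.SparkConf

    // Keep shuffle-block migration on (its default) but opt out of RDD-block
    // migration during block manager decommissioning.
    val conf = new SparkConf()
      .set("spark.storage.decommission.shuffle_blocks", "true")
      .set("spark.storage.decommission.rdd_blocks", "false")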

core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala

Lines changed: 11 additions & 1 deletion

@@ -33,9 +33,11 @@ import org.apache.spark.util.Utils
  * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
  */
 private[spark] sealed trait MapStatus {
-  /** Location where this task was run. */
+  /** Location where this task output is. */
   def location: BlockManagerId
 
+  def updateLocation(bm: BlockManagerId): Unit
+
   /**
    * Estimated size for the reduce block, in bytes.
    *
@@ -126,6 +128,10 @@ private[spark] class CompressedMapStatus(
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(bm: BlockManagerId): Unit = {
+    loc = bm
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     MapStatus.decompressSize(compressedSizes(reduceId))
   }
@@ -178,6 +184,10 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(bm: BlockManagerId): Unit = {
+    loc = bm
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     assert(hugeBlockSizes != null)
     if (emptyBlocks.contains(reduceId)) {
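
A small illustrative snippet of the new updateLocation contract; MapStatus is private[spark], so this only compiles from Spark-internal code, and the ids, sizes, and hosts are invented.

    import org.apache.spark.scheduler.MapStatus
    import org.apache.spark.storage.BlockManagerId

    // A status initially points at the executor that produced the output...
    val original = BlockManagerId("exec-1", "host-a", 7337)
    val status = MapStatus(original, Array(128L, 256L), 5L)

    // ...and can now be repointed at the block manager holding the migrated files.
    val migrated = BlockManagerId("exec-2", "host-b", 7337)
    status.updateLocation(migrated)
    assert(status.location == migrated)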

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

Lines changed: 103 additions & 1 deletion

@@ -18,15 +18,18 @@
 package org.apache.spark.shuffle
 
 import java.io._
+import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.file.Files
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -55,6 +58,25 @@ private[spark] class IndexShuffleBlockResolver(
 
   def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)
 
+  /**
+   * Get the shuffle files that are stored locally. Used for block migrations.
+   */
+  def getStoredShuffles(): Set[(Int, Long)] = {
+    // Matches ShuffleIndexBlockId name
+    val pattern = "shuffle_(\\d+)_(\\d+)_.+\\.index".r
+    val rootDirs = blockManager.diskBlockManager.localDirs
+    // ExecutorDiskUtils puts things inside one level hashed sub directories
+    val searchDirs = rootDirs.flatMap(_.listFiles()).filter(_.isDirectory()) ++ rootDirs
+    val filenames = searchDirs.flatMap(_.list())
+    logDebug(s"Got block files ${filenames.toList}")
+    filenames.flatMap { fname =>
+      pattern.findAllIn(fname).matchData.map {
+        matched => (matched.group(1).toInt, matched.group(2).toLong)
+      }
+    }.toSet
+  }
+
+
   /**
    * Get the shuffle data file.
    *
@@ -148,6 +170,86 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }
 
+  /**
+   * Write a provided shuffle block as a stream. Used for block migrations.
+   * ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock.
+   * Requires the caller to delete any shuffle index blocks where the shuffle block fails to
+   * put.
+   */
+  def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
+      StreamCallbackWithID = {
+    val file = blockId match {
+      case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+        getIndexFile(shuffleId, mapId)
+      case ShuffleBlockBatchId(shuffleId, mapId, _, _) =>
+        getDataFile(shuffleId, mapId)
+      case _ =>
+        throw new Exception(s"Unexpected shuffle block transfer ${blockId}")
+    }
+    val fileTmp = Utils.tempFileWith(file)
+    val channel = Channels.newChannel(
+      serializerManager.wrapStream(blockId,
+        new FileOutputStream(fileTmp)))
+
+    new StreamCallbackWithID {
+
+      override def getID: String = blockId.name
+
+      override def onData(streamId: String, buf: ByteBuffer): Unit = {
+        while (buf.hasRemaining) {
+          channel.write(buf)
+        }
+      }
+
+      override def onComplete(streamId: String): Unit = {
+        logTrace(s"Done receiving block $blockId, now putting into local shuffle service")
+        channel.close()
+        val diskSize = fileTmp.length()
+        this.synchronized {
+          if (file.exists()) {
+            file.delete()
+          }
+          if (!fileTmp.renameTo(file)) {
+            throw new IOException(s"fail to rename file ${fileTmp} to ${file}")
+          }
+        }
+        blockManager.reportBlockStatus(blockId, BlockStatus(
+          StorageLevel(
+            useDisk = true,
+            useMemory = false,
+            useOffHeap = false,
+            deserialized = false,
+            replication = 0),
+          0, diskSize))
+      }
+
+      override def onFailure(streamId: String, cause: Throwable): Unit = {
+        // the framework handles the connection itself, we just need to do local cleanup
+        channel.close()
+        fileTmp.delete()
+      }
+    }
+  }
+
+  /**
+   * Get the index & data block for migration.
+   */
+  def getMigrationBlocks(shuffleId: Int, mapId: Long):
+      ((BlockId, ManagedBuffer), (BlockId, ManagedBuffer)) = {
+    // Load the index block
+    val indexFile = getIndexFile(shuffleId, mapId)
+    val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0)
+    val indexFileSize = indexFile.length()
+    val indexBlockData = new FileSegmentManagedBuffer(transportConf, indexFile, 0, indexFileSize)
+
+    // Load the data block
+    val dataFile = getDataFile(shuffleId, mapId)
+    val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0)
+    val dataBlockData = new FileSegmentManagedBuffer(transportConf, dataFile, 0, dataFile.length())
+    ((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
+  }
+
+
   /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
@@ -169,7 +271,7 @@ private[spark] class IndexShuffleBlockResolver(
     val dataFile = getDataFile(shuffleId, mapId)
     // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
     // the following check and rename are atomic.
-    synchronized {
+    this.synchronized {
       val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
       if (existingLengths != null) {
         // Another attempt for the same task has already written our map outputs successfully,
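
To tie the new resolver hooks together, a rough sketch (not from this commit) of how a migration loop on a decommissioning executor might enumerate local shuffle files and obtain the buffers to send; the resolver value is assumed to be in scope and the actual push to a peer block manager is deliberately omitted.

    import org.apache.spark.shuffle.IndexShuffleBlockResolver

    // Illustrative only; IndexShuffleBlockResolver is private[spark].
    val resolver: IndexShuffleBlockResolver = ???
    resolver.getStoredShuffles().foreach { case (shuffleId, mapId) =>
      val ((indexBlockId, indexBuf), (dataBlockId, dataBuf)) =
        resolver.getMigrationBlocks(shuffleId, mapId)
      // The index block would be sent first (putShuffleBlockAsStream maps a
      // ShuffleIndexBlockId to the index file on the receiving side), then the data block.
    }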

core/src/main/scala/org/apache/spark/storage/BlockId.scala

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,9 @@ sealed abstract class BlockId {
   def isRDD: Boolean = isInstanceOf[RDDBlockId]
   def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
+  def isInternalShuffle: Boolean = {
+    isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId]
+  }
 
   override def toString: String = name
 }
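
A quick sketch of what the new predicate distinguishes, using the existing shuffle block id classes with arbitrary ids: isShuffle covers the logical blocks reducers fetch, while isInternalShuffle covers the on-disk data and index files that migration moves around.

    import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}

    assert(!ShuffleBlockId(0, 42L, 0).isInternalShuffle)     // logical shuffle block
    assert(ShuffleDataBlockId(0, 42L, 0).isInternalShuffle)  // shuffle data file
    assert(ShuffleIndexBlockId(0, 42L, 0).isInternalShuffle) // shuffle index file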
