38 changes: 38 additions & 0 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
@@ -96,12 +97,31 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, this)

// By default, shuffle merge is enabled for a ShuffleDependency when push-based shuffle
// is enabled.
private[this] var _shuffleMergeEnabled =
Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) &&
// TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
!rdd.isBarrier()

private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = {
_shuffleMergeEnabled = shuffleMergeEnabled
}

def shuffleMergeEnabled: Boolean = _shuffleMergeEnabled

/**
* Stores the location of the list of chosen external shuffle services for handling the
* shuffle merge requests from mappers in this shuffle map stage.
*/
private[spark] var mergerLocs: Seq[BlockManagerId] = Nil

/**
* Stores whether the shuffle merge is finalized for the shuffle map stage
* associated with this shuffle dependency.
*/
private[this] var _shuffleMergedFinalized: Boolean = false

def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
if (mergerLocs != null) {
this.mergerLocs = mergerLocs
@@ -110,6 +130,24 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](

def getMergerLocs: Seq[BlockManagerId] = mergerLocs

private[spark] def markShuffleMergeFinalized(): Unit = {
_shuffleMergedFinalized = true
}

/**
* Returns true if push-based shuffle is disabled for this stage or the RDD is empty,
* or if the shuffle merge for this stage is finalized, i.e. the shuffle merge
* results for all partitions are available.
*/
def shuffleMergeFinalized: Boolean = {
// An empty RDD won't be computed, so shuffle merge is treated as finalized by default.
if (shuffleMergeEnabled && rdd.getNumPartitions > 0) {
_shuffleMergedFinalized
} else {
true
}
}

_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
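
A minimal sketch (not part of this diff) of the push-merge lifecycle these new members give ShuffleDependency: merge is on by default when push-based shuffle is enabled and the stage is not a barrier stage, merger locations are recorded, and the merge is marked finalized once all merge results are in. The merger location, object name, and driver-side calls are illustrative only; the file is assumed to sit in the org.apache.spark package because markShuffleMergeFinalized and BlockManagerId's factory are private[spark].

// Sketch only, not from this PR.
package org.apache.spark

import org.apache.spark.storage.BlockManagerId

object ShuffleMergeLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val pairs = sc.parallelize(1 to 100, 4).map(i => (i % 10, i))
    // The ShuffledRDD produced by groupByKey carries the ShuffleDependency of interest.
    val dep = pairs.groupByKey().dependencies.head.asInstanceOf[ShuffleDependency[Int, Int, _]]

    if (dep.shuffleMergeEnabled) {
      // Hypothetical merger location; normally the DAGScheduler chooses these.
      dep.setMergerLocs(Seq(BlockManagerId("exec-1", "host-1", 7337)))
      dep.markShuffleMergeFinalized()
    }
    // True once finalized, and trivially true when merge is disabled or the RDD is empty.
    assert(dep.shuffleMergeFinalized)
    sc.stop()
  }
}
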
44 changes: 32 additions & 12 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -214,6 +214,7 @@ private class ShuffleStatus(
def removeOutputsOnHost(host: String): Unit = withWriteLock {
logDebug(s"Removing outputs for host ${host}")
removeOutputsByFilter(x => x.host == host)
removeMergeResultsByFilter(x => x.host == host)
}

/**
Expand All @@ -238,6 +239,12 @@ private class ShuffleStatus(
invalidateSerializedMapOutputStatusCache()
}
}
}

/**
* Removes all shuffle merge results that satisfy the filter.
*/
def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
for (reduceId <- mergeStatuses.indices) {
if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
_numAvailableMergeResults -= 1
@@ -708,15 +715,16 @@ private[spark] class MapOutputTrackerMaster(
}
}

- /** Unregister all map output information of the given shuffle. */
- def unregisterAllMapOutput(shuffleId: Int): Unit = {
+ /** Unregister all map and merge output information of the given shuffle. */
+ def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
shuffleStatus.removeMergeResultsByFilter(x => true)
incrementEpoch()
case None =>
throw new SparkException(
s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.")
s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID $shuffleId.")
}
}

@@ -731,25 +739,26 @@
}

/**
- * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId
- * is specified, it will only unregister the merge result if the mapId is part of that merge
+ * Unregisters a merge result corresponding to the reduceId if present. If the optional mapIndex
+ * is specified, it will only unregister the merge result if the mapIndex is part of that merge
* result.
*
* @param shuffleId the shuffleId.
* @param reduceId the reduceId.
* @param bmAddress block manager address.
- * @param mapId the optional mapId which should be checked to see it was part of the merge
- * result.
+ * @param mapIndex the optional mapIndex which should be checked to see if it was part of the
+ * merge result.
*/
def unregisterMergeResult(
- shuffleId: Int,
- reduceId: Int,
- bmAddress: BlockManagerId,
- mapId: Option[Int] = None): Unit = {
+ shuffleId: Int,
+ reduceId: Int,
+ bmAddress: BlockManagerId,
+ mapIndex: Option[Int] = None): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
- if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) {
+ if (mergeStatus != null &&
+ (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
shuffleStatus.removeMergeResult(reduceId, bmAddress)
incrementEpoch()
}
@@ -758,6 +767,17 @@
}
}

def unregisterAllMergeResult(shuffleId: Int): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeMergeResultsByFilter(x => true)
incrementEpoch()
case None =>
throw new SparkException(
s"unregisterAllMergeResult called for nonexistent shuffle ID $shuffleId.")
}
}
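
For illustration only, a hypothetical sketch of how a driver-side caller (for example, the DAGScheduler reacting to a failed fetch of a merged block) might drive these new tracker methods. The helper name, its parameters, and the wholeStageRetry flag are invented for this sketch, and the file is assumed to sit in org.apache.spark since MapOutputTrackerMaster is private[spark].

// Sketch only, not from this PR.
package org.apache.spark

import org.apache.spark.storage.BlockManagerId

object MergeResultCleanupSketch {
  // Hypothetical helper: decide how much merge state to drop after a fetch failure.
  def handleMergedBlockFetchFailure(
      tracker: MapOutputTrackerMaster,
      shuffleId: Int,
      reduceId: Int,
      mergerLoc: BlockManagerId,
      wholeStageRetry: Boolean): Unit = {
    if (wholeStageRetry) {
      // Rerunning the whole map stage: drop both map outputs and merge results.
      tracker.unregisterAllMapAndMergeOutput(shuffleId)
    } else {
      // Only this merged partition is suspect: drop just its merge result so reducers
      // fall back to the original, unmerged map outputs for that partition.
      tracker.unregisterMergeResult(shuffleId, reduceId, mergerLoc)
    }
  }
}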

/** Unregister shuffle data */
def unregisterShuffle(shuffleId: Int): Unit = {
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
@@ -2084,6 +2084,27 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT =
ConfigBuilder("spark.shuffle.push.merge.results.timeout")
.doc("Specify the max amount of time DAGScheduler waits for the merge results from " +
"all remote shuffle services for a given shuffle. DAGScheduler will start to submit " +
"following stages if not all results are received within the timeout.")
.version("3.2.0")
.timeConf(TimeUnit.SECONDS)
.checkValue(_ >= 0L, "Timeout must be >= 0.")
.createWithDefaultString("10s")

private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT =
ConfigBuilder("spark.shuffle.push.merge.finalize.timeout")
.doc("Specify the amount of time DAGScheduler waits after all mappers finish for " +
"a given shuffle map stage before it starts sending merge finalize requests to " +
"remote shuffle services. This allows the shuffle services some extra time to " +
"merge as many blocks as possible.")
.version("3.2.0")
.timeConf(TimeUnit.SECONDS)
.checkValue(_ >= 0L, "Timeout must be >= 0.")
.createWithDefaultString("10s")
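
For illustration, a minimal sketch of tuning these two new timeouts on a SparkConf. The key strings mirror the ConfigBuilder entries above; spark.shuffle.push.enabled is assumed to be the separate flag that turns push-based shuffle on and is not part of this hunk.

import org.apache.spark.SparkConf

// Sketch only: raise the merge-results wait and the finalize delay for a push-heavy job.
val conf = new SparkConf()
  .set("spark.shuffle.push.enabled", "true")               // assumed master switch, not in this hunk
  .set("spark.shuffle.push.merge.results.timeout", "30s")  // DAGScheduler waits up to 30s for merge results
  .set("spark.shuffle.push.merge.finalize.timeout", "15s") // extra time before sending finalize requests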

private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
.doc("Maximum number of shuffle push merger locations cached for push based shuffle. " +
@@ -2117,7 +2138,7 @@
s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " +
"at least 50 mergers to enable push based shuffle for that stage.")
.version("3.1.0")
- .doubleConf
+ .intConf
.createWithDefault(5)

private[spark] val SHUFFLE_NUM_PUSH_THREADS =