@@ -52,8 +52,7 @@ import org.apache.spark.util._
5252 */
5353private class ShuffleStatus (
5454 numPartitions : Int ,
55- numReducers : Int ,
56- isPushBasedShuffleEnabled : Boolean = false ) extends Logging {
55+ numReducers : Int = - 1 ) extends Logging {
5756
5857 private val (readLock, writeLock) = {
5958 val lock = new ReentrantReadWriteLock ()
@@ -97,7 +96,7 @@ private class ShuffleStatus(
9796 * provides a reducer oriented view of the shuffle status specifically for the results of
9897 * merging shuffle partition blocks into per-partition merged shuffle files.
9998 */
100- val mergeStatuses = if (isPushBasedShuffleEnabled ) {
99+ val mergeStatuses = if (numReducers > 0 ) {
101100 new Array [MergeStatus ](numReducers)
102101 } else {
103102 Array .empty[MergeStatus ]
@@ -674,9 +673,14 @@ private[spark] class MapOutputTrackerMaster(
674673 }
675674
676675 def registerShuffle (shuffleId : Int , numMaps : Int , numReduces : Int ): Unit = {
677- if (shuffleStatuses.put(shuffleId,
678- new ShuffleStatus (numMaps, numReduces, pushBasedShuffleEnabled)).isDefined) {
679- throw new IllegalArgumentException (" Shuffle ID " + shuffleId + " registered twice" )
676+ if (pushBasedShuffleEnabled) {
677+ if (shuffleStatuses.put(shuffleId, new ShuffleStatus (numMaps, numReduces)).isDefined) {
678+ throw new IllegalArgumentException (" Shuffle ID " + shuffleId + " registered twice" )
679+ }
680+ } else {
681+ if (shuffleStatuses.put(shuffleId, new ShuffleStatus (numMaps)).isDefined) {
682+ throw new IllegalArgumentException (" Shuffle ID " + shuffleId + " registered twice" )
683+ }
680684 }
681685 }
682686
@@ -1399,7 +1403,8 @@ private[spark] object MapOutputTracker extends Logging {
13991403 // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
14001404 // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
14011405 // TODO: map indexes
1402- if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
1406+ if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0
1407+ && endMapIndex == mapStatuses.length) {
14031408 // We have MergeStatus and full range of mapIds are requested so return a merged block.
14041409 val numMaps = mapStatuses.length
14051410 mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
@@ -1413,7 +1418,8 @@ private[spark] object MapOutputTracker extends Logging {
14131418 ((ShuffleBlockId (shuffleId, SHUFFLE_PUSH_MAP_ID , partId), mergeStatus.totalSize, - 1 ))
14141419 // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
14151420 // shuffle partition blocks, fetch the original map produced shuffle partition blocks
1416- mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex)
1421+ val mapStatusesWithIndex = mapStatuses.zipWithIndex
1422+ mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex)
14171423 } else {
14181424 // If MergeStatus is not available for the given partition, fall back to
14191425 // fetching all the original mapper shuffle partition blocks
0 commit comments