128 changes: 126 additions & 2 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -356,8 +356,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
: Seq[String] = {
if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
// Replace getLocationsWithLargestOutputs with getLocationsWithGlobalMode
val blockManagerIds = getLocationsWithGlobalMode(dep.shuffleId,
dep.partitioner.numPartitions)
if (blockManagerIds.nonEmpty) {
blockManagerIds.get.map(_.host)
} else {
@@ -421,6 +422,129 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
None
}

/**
 * Return a preferred location for each reduce partition of a shuffle. Reducers are grouped
 * so that every node receives a roughly equal share of the total map output (load
 * balancing) while the amount of remotely fetched data is minimized.
 *
 * @param shuffleId id of the shuffle
 * @param numReducers total number of reducers in the shuffle
 */
def getLocationsWithGlobalMode(
shuffleId: Int, numReducers: Int): Option[Array[BlockManagerId]] = {
val statuses = mapStatuses.get(shuffleId).orNull
assert(statuses != null)
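// Aggregate, for each map output location, the bytes it holds for every reducer.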
val splitsByLocation = new HashMap[BlockManagerId, Array[Long]]
var sumOfAllBytes: Long = 0
statuses.foreach { status =>
if (status == null) {
throw new MetadataFetchFailedException(
shuffleId, -1, "Missing an output location for shuffle " + shuffleId)
} else {
val location = status.location
if (!splitsByLocation.contains(location)) {
splitsByLocation(location) = new Array[Long](numReducers)
}
for (index <- 0 until numReducers) {
val byteSize = status.getSizeForBlock(index)
splitsByLocation(location)(index) += byteSize
sumOfAllBytes += byteSize
}
}
}
if (splitsByLocation.nonEmpty) {
val numOfLocations = splitsByLocation.size
val preferredLocationsOfReduces = new Array[BlockManagerId](numReducers)
val bytesOfReduces = new Array[Long](numReducers)
val blockManagerIdMaps = new HashMap[Int, BlockManagerId]
val splitIndexOfLocation = new Array[HashSet[Int]](numOfLocations)
var locIndex = 0
// Calculate the total byte size of each reducer across all locations, and record a
// stable index for every location. Later passes rely on splitsByLocation keeping this
// iteration order, which holds because the map is not modified after this point.
splitsByLocation.foreach { case (blockManagerId, byteSizes) =>
blockManagerIdMaps(locIndex) = blockManagerId
splitIndexOfLocation(locIndex) = new HashSet[Int]
for (index <- byteSizes.indices) {
bytesOfReduces(index) += byteSizes(index)
}
locIndex += 1
}

// Pair each reducer index with its total size, then sort by size, largest first.
// (Ties are broken by HashMap iteration order, so equal-sized reducers may be
// ordered arbitrarily.)
val indexOfBytesOfReduces = new HashMap[Int, Long]
for ((size, index) <- bytesOfReduces.zipWithIndex) {
indexOfBytesOfReduces.getOrElseUpdate(index, size)
}
val sortedIndexOfBytesOfReducer = indexOfBytesOfReduces.toSeq.sortWith(_._2 > _._2)
val splitSumOfByteSizeOfLocation = new Array[Long](numOfLocations)

// Divide the reduce tasks into n groups according to the number of nodes and the data
// size, keeping the data size of each group nearly equal to achieve load balancing.
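// E.g. with 3 locations and per-reducer totals 9, 7, 5, 4, 3 (already descending), the
// greedy pass below forms the groups {9}, {7, 3} and {5, 4}, with totals 9, 10 and 9.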
for (index <- sortedIndexOfBytesOfReducer.indices) {
var minLocIndex = 0
for (locIndex <- 1 until numOfLocations) {
if (splitSumOfByteSizeOfLocation(locIndex) < splitSumOfByteSizeOfLocation(minLocIndex)) {
minLocIndex = locIndex
}
}
val (loc, byteSize) = sortedIndexOfBytesOfReducer(index)
splitSumOfByteSizeOfLocation(minLocIndex) += byteSize
splitIndexOfLocation(minLocIndex).add(loc)
}

// Determine the amount of local data if the reducers of each group were executed on each
// node. This yields an n x n matrix whose rows are groups and whose columns are locations.
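// E.g. splitBytesOfLocationsAndGroup(g)(l) answers: if group g is scheduled on
// location l, how many bytes of its input are already local to l?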
val splitBytesOfLocationsAndGroup = new Array[Array[Long]](numOfLocations)
for (index <- splitBytesOfLocationsAndGroup.indices) {
splitBytesOfLocationsAndGroup(index) = new Array[Long](numOfLocations)
}
// Hoist the location sequence once; its order matches the locIndex mapping above.
val locationsSeq = splitsByLocation.toSeq
for (row <- splitIndexOfLocation.indices) {
val iter: Iterator[Int] = splitIndexOfLocation(row).iterator
while (iter.hasNext) {
val index = iter.next()
for (col <- locationsSeq.indices) {
splitBytesOfLocationsAndGroup(row)(col) += locationsSeq(col)._2(index)
}
}
}
// Choose the largest value in the matrix to decide which group is allocated to which node.
// Mark the row and column of the selected entry so that neither this group nor this node
// can be chosen again, and repeat until every group has been assigned.
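// E.g. for the matrix [[9, 2], [3, 8]], the entry 9 assigns group 0 to location 0; after
// its row and column are retired, the entry 8 assigns group 1 to location 1.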
for (index <- 0 until numOfLocations) {
var maxCol = 0
var maxRow = 0
var maxValue = splitBytesOfLocationsAndGroup(maxRow)(maxCol)
for (row <- splitBytesOfLocationsAndGroup.indices) {
for (col <- splitBytesOfLocationsAndGroup(row).indices) {
if (splitBytesOfLocationsAndGroup(row)(col) > maxValue) {
maxRow = row
maxCol = col
maxValue = splitBytesOfLocationsAndGroup(row)(col)
}
}
}
val iter: Iterator[Int] = splitIndexOfLocation(maxRow).iterator
while (iter.hasNext) {
val index = iter.next()
preferredLocationsOfReduces(index) = blockManagerIdMaps(maxCol)
}
for (row <- splitBytesOfLocationsAndGroup.indices) {
splitBytesOfLocationsAndGroup(row)(maxCol) = -1
}
for (col <- splitBytesOfLocationsAndGroup(maxRow).indices) {
splitBytesOfLocationsAndGroup(maxRow)(col) = -1
}
}
Some(preferredLocationsOfReduces)
} else {
None
}
}

def incrementEpoch() {
epochLock.synchronized {
epoch += 1
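To make the four steps above easier to review, here is a minimal, self-contained sketch of the same grouping-and-assignment idea, operating on plain arrays instead of MapStatus and BlockManagerId; the object, method, and variable names are illustrative only and not part of this patch:

import scala.collection.mutable

object GlobalModeSketch {
  // bytesByLocation(l)(r) = bytes that location l holds for reducer r.
  // Returns assignment(r) = index of the location preferred for reducer r.
  def assign(bytesByLocation: Array[Array[Long]]): Array[Int] = {
    val numLocations = bytesByLocation.length
    val numReducers = bytesByLocation(0).length

    // Step 1: total bytes per reducer, then reducer indices sorted largest first.
    val reducerTotals = Array.tabulate(numReducers)(r => bytesByLocation.map(_(r)).sum)
    val sortedReducers = reducerTotals.zipWithIndex.sortBy(-_._1).map(_._2)

    // Step 2: greedily pack reducers into numLocations groups of near-equal total size.
    val groupBytes = new Array[Long](numLocations)
    val groups = Array.fill(numLocations)(mutable.Buffer.empty[Int])
    for (r <- sortedReducers) {
      val g = groupBytes.indices.minBy(i => groupBytes(i))
      groupBytes(g) += reducerTotals(r)
      groups(g) += r
    }

    // Step 3: matrix(g)(l) = bytes of group g that are already local to location l.
    val matrix = Array.tabulate(numLocations, numLocations) { (g, l) =>
      groups(g).map(r => bytesByLocation(l)(r)).sum
    }

    // Step 4: repeatedly pick the largest remaining entry, assign that group to that
    // location, and retire the entry's row (group) and column (location).
    val assignment = new Array[Int](numReducers)
    val usedGroups = mutable.Set.empty[Int]
    val usedLocs = mutable.Set.empty[Int]
    for (_ <- 0 until numLocations) {
      var (bestG, bestL, bestV) = (-1, -1, Long.MinValue)
      for (g <- 0 until numLocations if !usedGroups(g);
           l <- 0 until numLocations if !usedLocs(l)) {
        if (matrix(g)(l) > bestV) {
          bestG = g; bestL = l; bestV = matrix(g)(l)
        }
      }
      usedGroups += bestG
      usedLocs += bestL
      groups(bestG).foreach(r => assignment(r) = bestL)
    }
    assignment
  }

  def main(args: Array[String]): Unit = {
    // Two locations, four reducers: reducers 0 and 2 live mostly on location 0,
    // reducers 1 and 3 mostly on location 1, so the expected assignment is (0, 1, 0, 1).
    val bytes = Array(Array(9L, 1L, 8L, 2L), Array(1L, 9L, 2L, 8L))
    println(assign(bytes).mkString(", "))
  }
}

The main method prints "0, 1, 0, 1" for the toy input, matching the intuition that each reducer should run where most of its input already lives.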
29 changes: 29 additions & 0 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -242,4 +242,33 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.stop()
rpcEnv.shutdown()
}


test("getLocationsWithOverAllSituation with multiple outputs") {
val rpcEnv = createRpcEnv("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
// Set up 4 map tasks
tracker.registerShuffle(10, 4)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(1L, 2L, 3L, 4L, 5L, 6L, 7L)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(2L, 8L, 2L, 5L, 2L, 2L, 9L)))
tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("c", "hostC", 1000),
Array(3L, 2L, 7L, 4L, 2L, 6L, 1L)))
tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("c", "hostC", 1000),
Array(8L, 5L, 3L, 4L, 7L, 2L, 5L)))

val topLocs = tracker.getLocationsWithGlobalMode(10, 7)
assert(topLocs.nonEmpty)
assert(topLocs.get.size === 7)
assert(topLocs.get ===
Array(BlockManagerId("c", "hostC", 1000), BlockManagerId("c", "hostC", 1000),
BlockManagerId("b", "hostB", 1000), BlockManagerId("a", "hostA", 1000),
BlockManagerId("a", "hostA", 1000), BlockManagerId("c", "hostC", 1000),
BlockManagerId("b", "hostB", 1000)))

tracker.stop()
rpcEnv.shutdown()
}
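
Because the exact expected array above depends on hash-map iteration order when reducers tie in total size, an order-independent sanity check may also be worth adding; a minimal sketch under the same suite setup (the test name is illustrative):

  test("getLocationsWithGlobalMode with a single output location") {
    val rpcEnv = createRpcEnv("test")
    val tracker = new MapOutputTrackerMaster(conf)
    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
    tracker.registerShuffle(10, 1)
    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
      Array(1L, 2L, 3L)))
    val locs = tracker.getLocationsWithGlobalMode(10, 3)
    // With a single map output location, every reducer must prefer that location.
    assert(locs.get.forall(_ == BlockManagerId("a", "hostA", 1000)))
    tracker.stop()
    rpcEnv.shutdown()
  }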
}