@@ -21,7 +21,7 @@ import java.io._
2121import java .util .concurrent .ConcurrentHashMap
2222import java .util .zip .{GZIPInputStream , GZIPOutputStream }
2323
24- import scala .collection .mutable .{HashSet , Map }
24+ import scala .collection .mutable .{HashMap , HashSet , Map }
2525import scala .collection .JavaConversions ._
2626import scala .reflect .ClassTag
2727
@@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
284284 cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
285285 }
286286
/**
 * Return the locations that each hold at least the given fraction of the map output
 * destined for one reducer.
 *
 * The sizes of all blocks for `reducerId` are summed per location across the map statuses
 * of the shuffle. A location qualifies when its share of the summed total is greater than
 * or equal to `fractionThreshold`.
 *
 * @param shuffleId id of the shuffle
 * @param reducerId id of the reduce task
 * @param numReducers total number of reducers in the shuffle (not used in the computation
 *                    below; kept for interface compatibility)
 * @param fractionThreshold fraction of total map output size that a location must have
 *                          for it to be considered large.
 * @return `Some` array of qualifying locations, or `None` when the shuffle has no registered
 *         statuses, the status array is empty, or no location reaches the threshold.
 *
 * This method is not thread-safe.
 */
def getLocationsWithLargestOutputs(
    shuffleId: Int,
    reducerId: Int,
    numReducers: Int,
    fractionThreshold: Double)
  : Option[Array[BlockManagerId]] = {

  // Look the entry up once. A `contains` check followed by `apply` is a check-then-act
  // sequence that throws if the entry is removed between the two calls.
  mapStatuses.get(shuffleId).filter(_.nonEmpty).flatMap { statuses =>
    // Accumulates the total size of all blocks stored at the same location.
    val sizeByLocation = new HashMap[BlockManagerId, Long]
    var totalOutputSize = 0L
    var mapIdx = 0
    while (mapIdx < statuses.length) {
      val status = statuses(mapIdx)
      // NOTE(review): assumes the status array may contain null placeholders for map
      // outputs that have not been registered yet -- confirm against the code that
      // creates the array. Such slots are skipped instead of failing with an NPE.
      if (status != null) {
        val blockSize = status.getSizeForBlock(reducerId)
        if (blockSize > 0) {
          sizeByLocation(status.location) =
            sizeByLocation.getOrElse(status.location, 0L) + blockSize
          totalOutputSize += blockSize
        }
      }
      mapIdx += 1
    }
    // The division is only evaluated when at least one positive block size was recorded,
    // so totalOutputSize is non-zero whenever the predicate runs.
    val topLocs = sizeByLocation.filter { case (_, size) =>
      size.toDouble / totalOutputSize >= fractionThreshold
    }
    // Only report a result when some location satisfies the required threshold.
    if (topLocs.nonEmpty) Some(topLocs.keys.toArray) else None
  }
}
333+
287334 def incrementEpoch () {
288335 epochLock.synchronized {
289336 epoch += 1
0 commit comments