37 changes: 20 additions & 17 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1383,33 +1383,36 @@ class DAGScheduler(
       return rddPrefs.map(TaskLocation(_))
     }
 
-    // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
-    // that has any placement preferences. Ideally we would choose based on transfer sizes,
-    // but this will do for now.
     rdd.dependencies.foreach {
       case n: NarrowDependency[_] =>
+        // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+        // that has any placement preferences. Ideally we would choose based on transfer sizes,
+        // but this will do for now.
         for (inPart <- n.getParents(partition)) {
           val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
           if (locs != Nil) {
             return locs
           }
         }
-      case s: ShuffleDependency[_, _, _] =>
-        // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION
-        // of data as preferred locations
-        if (shuffleLocalityEnabled &&
-            rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD &&
-            s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) {
-          // Get the preferred map output locations for this reducer
-          val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId,
-            partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION)
-          if (topLocsForReducer.nonEmpty) {
-            return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId))
-          }
-        }
-
       case _ =>
     }
+
+    // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that
+    // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations
+    if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+      rdd.dependencies.foreach {
+        case s: ShuffleDependency[_, _, _] =>
+          if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) {
+            // Get the preferred map output locations for this reducer
+            val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId,
+              partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION)
+            if (topLocsForReducer.nonEmpty) {
+              return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId))
+            }
+          }
+        case _ =>
+      }
+    }
     Nil
   }

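The net effect of the move: previously a single pass over `rdd.dependencies` could return shuffle-based locations before a later narrow dependency was ever examined; now narrow-dependency preferences are always exhausted first, and shuffle locality is only a fallback. Below is a self-contained toy model of that ordering (the `Narrow`/`Shuffle` stand-ins and all names are illustrative, not Spark's real classes; the real code also gates the fallback on `shuffleLocalityEnabled` and the `SHUFFLE_PREF_*` size thresholds):

```scala
// Toy model of the two-pass preference order introduced by this diff.
object LocalityOrderSketch {
  sealed trait Dep
  // Stand-in for NarrowDependency: the narrow parent's placement preferences.
  final case class Narrow(parentPrefs: Seq[String]) extends Dep
  // Stand-in for ShuffleDependency: hosts holding a large share of the map
  // output (Spark requires at least REDUCER_PREF_LOCS_FRACTION of it).
  final case class Shuffle(topOutputHosts: Seq[String]) extends Dep

  def preferredLocs(deps: Seq[Dep]): Seq[String] = {
    // Pass 1: the first narrow dependency with any preference wins outright.
    val narrow = deps.collectFirst { case Narrow(prefs) if prefs.nonEmpty => prefs }
    // Pass 2: only if no narrow parent had preferences, fall back to the
    // shuffle dependencies' best map-output hosts.
    narrow
      .orElse(deps.collectFirst { case Shuffle(hosts) if hosts.nonEmpty => hosts })
      .getOrElse(Nil)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the new test below: map output on hostA, narrow parent on hostB.
    val deps = Seq(Shuffle(Seq("hostA")), Narrow(Seq("hostB")))
    println(preferredLocs(deps)) // List(hostB): the narrow parent decides.
  }
}
```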
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -926,7 +926,7 @@ class DAGSchedulerSuite
     assertLocations(reduceTaskSet, Seq(Seq("hostA")))
     complete(reduceTaskSet, Seq((Success, 42)))
     assert(results === Map(0 -> 42))
-    assertDataStructuresEmpty
+    assertDataStructuresEmpty()
   }
 
   test("reduce task locality preferences should only include machines with largest map outputs") {
@@ -950,7 +950,29 @@
     assertLocations(reduceTaskSet, Seq(hosts))
     complete(reduceTaskSet, Seq((Success, 42)))
     assert(results === Map(0 -> 42))
-    assertDataStructuresEmpty
+    assertDataStructuresEmpty()
   }
 
+  test("stages with both narrow and shuffle dependencies use narrow ones for locality") {
+    // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. for a join)
+    val rdd1 = new MyRDD(sc, 1, Nil)
+    val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB")))
+    val shuffleDep = new ShuffleDependency(rdd1, null)
+    val narrowDep = new OneToOneDependency(rdd2)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep))
+    submit(reduceRdd, Array(0))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1))))
+    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+      HashSet(makeBlockManagerId("hostA")))
+
+    // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep
+    val reduceTaskSet = taskSets(1)
+    assertLocations(reduceTaskSet, Seq(Seq("hostB")))
+    complete(reduceTaskSet, Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+    assertDataStructuresEmpty()
+  }
+
   test("Spark exceptions should include call site in stack trace") {
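For context on why a stage can carry both dependency kinds at once, as the new test constructs by hand: an ordinary join does this whenever one input is already partitioned the way the join wants. A hedged sketch against the public Spark API (the local-mode setup and names are illustrative and not from this PR; `CoGroupedRDD` is the developer-API class that `join`/`cogroup` build internally):

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.CoGroupedRDD

object JoinDepsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("deps-sketch"))
    val part = new HashPartitioner(2)
    // The pre-partitioned side matches the cogroup's partitioner, so it becomes
    // a narrow (OneToOneDependency) parent; the other side must be shuffled.
    val prePartitioned = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(part)
    val unpartitioned = sc.parallelize(Seq(1 -> "x", 2 -> "y"))
    val cg = new CoGroupedRDD[Int](Seq(prePartitioned, unpartitioned), part)
    // Expected output: List(OneToOneDependency, ShuffleDependency). With this
    // patch, the reduce stage's locality follows the narrow parent.
    println(cg.dependencies.map(_.getClass.getSimpleName))
    sc.stop()
  }
}
```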