
Commit 1a09552

[SPARK-9851] Support submitting map stages individually in DAGScheduler
This patch adds support for submitting map stages in a DAG individually so that we can make downstream decisions after seeing statistics about their output, as part of SPARK-9850. I also added more comments to many of the key classes in DAGScheduler.

By itself, the patch is not super useful except maybe to switch between a shuffle and broadcast join, but with the other subtasks of SPARK-9850 we'll be able to make more interesting decisions.

The main entry point is SparkContext.submitMapStage, which lets you run a map stage and see stats about the map output sizes. Other stats could also be collected through accumulators. See AdaptiveSchedulingSuite for a short example.

Author: Matei Zaharia <[email protected]>

Closes #8180 from mateiz/spark-9851.
1 parent 7b6c856 commit 1a09552
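
For orientation, a rough sketch of the usage pattern the commit message describes, assuming code that lives inside the org.apache.spark package (submitMapStage is private[spark]); the dependency setup and the broadcastThreshold value are illustrative, not part of this patch — see AdaptiveSchedulingSuite for the real example:

    import org.apache.spark.{HashPartitioner, MapOutputStatistics, ShuffleDependency}

    // Assumed context: an existing SparkContext `sc`; the threshold value is made up.
    val broadcastThreshold = 10L * 1024 * 1024

    // Build a shuffle dependency over some pair RDD.
    val pairs = sc.parallelize(1 to 1000, 10).map(x => (x % 10, x))
    val dep = new ShuffleDependency[Int, Int, Int](pairs, new HashPartitioner(5))

    // Run only the map side of the shuffle and block until its statistics arrive.
    val stats: MapOutputStatistics = sc.submitMapStage(dep).get()

    // Use the observed output sizes to pick a downstream strategy.
    if (stats.bytesByPartitionId.sum < broadcastThreshold) {
      println("map output is small; a broadcast join would be reasonable")
    } else {
      println("map output is large; keep the shuffle-based plan")
    }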

File tree

12 files changed: +710 −63 lines
core/src/main/scala/org/apache/spark/MapOutputStatistics.scala

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future.
+ *
+ * @param shuffleId ID of the shuffle
+ * @param bytesByPartitionId approximate number of output bytes for each map output partition
+ *   (may be inexact due to use of compressed map statuses)
+ */
+private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
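
As a quick illustration of how a scheduler-side caller might interpret these statistics, here is a hypothetical helper (not part of this patch) that flags reduce partitions whose estimated input size is far above the average:

    // Hypothetical helper: detect skewed reduce partitions from the per-partition byte counts.
    def skewedPartitions(stats: MapOutputStatistics, factor: Double = 4.0): Seq[Int] = {
      val sizes = stats.bytesByPartitionId
      val avg = if (sizes.isEmpty) 0.0 else sizes.sum.toDouble / sizes.length
      sizes.indices.filter(i => sizes(i) > factor * avg)
    }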

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 38 additions & 11 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark

 import java.io._
+import java.util.Arrays
 import java.util.concurrent.ConcurrentHashMap
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}

@@ -132,13 +133,43 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    * describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
-    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
-    val startTime = System.currentTimeMillis
+    val statuses = getStatuses(shuffleId)
+    // Synchronize on the returned array because, on the driver, it gets mutated in place
+    statuses.synchronized {
+      return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+    }
+  }

+  /**
+   * Return statistics about all of the outputs for a given shuffle.
+   */
+  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
+    val statuses = getStatuses(dep.shuffleId)
+    // Synchronize on the returned array because, on the driver, it gets mutated in place
+    statuses.synchronized {
+      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
+      }
+      new MapOutputStatistics(dep.shuffleId, totalSizes)
+    }
+  }
+
+  /**
+   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * on this array when reading it, because on the driver, we may be changing it in place.
+   *
+   * (It would be nice to remove this restriction in the future.)
+   */
+  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
     val statuses = mapStatuses.get(shuffleId).orNull
     if (statuses == null) {
       logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+      val startTime = System.currentTimeMillis
       var fetchedStatuses: Array[MapStatus] = null
       fetching.synchronized {
         // Someone else is fetching it; wait for them to be done
@@ -160,7 +191,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       }

       if (fetchedStatuses == null) {
-        // We won the race to fetch the output locs; do so
+        // We won the race to fetch the statuses; do so
         logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
         // This try-finally prevents hangs due to timeouts:
         try {
@@ -175,22 +206,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
           }
         }
       }
-      logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " +
+      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
         s"${System.currentTimeMillis - startTime} ms")

       if (fetchedStatuses != null) {
-        fetchedStatuses.synchronized {
-          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
-        }
+        return fetchedStatuses
       } else {
         logError("Missing all output locations for shuffle " + shuffleId)
         throw new MetadataFetchFailedException(
-          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
+          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
       }
     } else {
-      statuses.synchronized {
-        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
-      }
+      return statuses
     }
   }
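
The new getStatistics method sums, for each reduce partition, the sizes that every map task reported for that partition (the per-block sizes may be approximate because MapStatus compresses them). A standalone sketch of the same aggregation, with plain arrays standing in for MapStatus objects:

    // Toy data: one row per map task, one entry per reduce partition (bytes).
    val perMapSizes: Array[Array[Long]] = Array(
      Array(10L, 0L, 5L),  // sizes reported by map task 0
      Array(7L, 3L, 0L)    // sizes reported by map task 1
    )
    val numPartitions = 3
    val totalSizes = new Array[Long](numPartitions)
    for (s <- perMapSizes; i <- 0 until numPartitions) {
      totalSizes(i) += s(i)
    }
    // totalSizes is now Array(17, 3, 5): total bytes destined for each reduce partition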

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 17 additions & 0 deletions

@@ -1984,6 +1984,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     new SimpleFutureAction(waiter, resultFunc)
   }

+  /**
+   * Submit a map stage for execution. This is currently an internal API only, but might be
+   * promoted to DeveloperApi in the future.
+   */
+  private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C])
+      : SimpleFutureAction[MapOutputStatistics] = {
+    assertNotStopped()
+    val callSite = getCallSite()
+    var result: MapOutputStatistics = null
+    val waiter = dagScheduler.submitMapStage(
+      dependency,
+      (r: MapOutputStatistics) => { result = r },
+      callSite,
+      localProperties.get)
+    new SimpleFutureAction[MapOutputStatistics](waiter, result)
+  }
+
   /**
    * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
    * for more information.
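
Because SimpleFutureAction implements the standard Scala Future interface, a caller does not have to block on the statistics. A minimal sketch, assuming code inside org.apache.spark with an active SparkContext `sc` and a ShuffleDependency `dep` already in scope:

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    sc.submitMapStage(dep).onComplete {
      case Success(stats) =>
        println(s"Shuffle ${stats.shuffleId} produced ${stats.bytesByPartitionId.sum} bytes")
      case Failure(e) =>
        println(s"Map stage failed: $e")
    }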

core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala

Lines changed: 29 additions & 5 deletions

@@ -23,18 +23,42 @@ import org.apache.spark.TaskContext
 import org.apache.spark.util.CallSite

 /**
- * Tracks information about an active job in the DAGScheduler.
+ * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a
+ * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a
+ * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive
+ * query planning, to look at map output statistics before submitting later stages. We distinguish
+ * between these two types of jobs using the finalStage field of this class.
+ *
+ * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's
+ * submitJob or submitMapStage methods. However, either type of job may cause the execution of
+ * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of
+ * these previous stages. These dependencies are managed inside DAGScheduler.
+ *
+ * @param jobId A unique ID for this job.
+ * @param finalStage The stage that this job computes (either a ResultStage for an action or a
+ *   ShuffleMapStage for submitMapStage).
+ * @param callSite Where this job was initiated in the user's program (shown on UI).
+ * @param listener A listener to notify if tasks in this job finish or the job fails.
+ * @param properties Scheduling properties attached to the job, such as fair scheduler pool name.
  */
 private[spark] class ActiveJob(
     val jobId: Int,
-    val finalStage: ResultStage,
-    val func: (TaskContext, Iterator[_]) => _,
-    val partitions: Array[Int],
+    val finalStage: Stage,
     val callSite: CallSite,
     val listener: JobListener,
     val properties: Properties) {

-  val numPartitions = partitions.length
+  /**
+   * Number of partitions we need to compute for this job. Note that result stages may not need
+   * to compute all partitions in their target RDD, for actions like first() and lookup().
+   */
+  val numPartitions = finalStage match {
+    case r: ResultStage => r.partitions.length
+    case m: ShuffleMapStage => m.rdd.partitions.length
+  }
+
+  /** Which partitions of the stage have finished */
   val finished = Array.fill[Boolean](numPartitions)(false)
+
   var numFinished = 0
 }
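
The numPartitions distinction above matters because a result job may run over only a subset of its RDD's partitions, while a map-stage job always covers every partition of the RDD being shuffled. A small illustration (the actions are standard Spark; the concrete numbers are only an example):

    val rdd = sc.parallelize(1 to 100, 4)
    rdd.first()   // result job: the ResultStage needs only partition 0, so its ActiveJob tracks 1 partition
    rdd.count()   // result job over all 4 partitions, so its ActiveJob tracks 4
    // A map-stage job for a shuffle over `rdd` would always track rdd.partitions.length == 4.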
