Commit aeb79c7

Merge branch 'master' of github.com:apache/spark into handle-configs-bash
2 parents: 2732ac0 + ec79063

34 files changed: +662 −221 lines


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 18 additions & 9 deletions
@@ -156,11 +156,9 @@ object SparkEnv extends Logging {
       conf.set("spark.driver.port", boundPort.toString)
     }
 
-    // Create an instance of the class named by the given Java system property, or by
-    // defaultClassName if the property is not set, and return it as a T
-    def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
-      val name = conf.get(propertyName, defaultClassName)
-      val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
+    // Create an instance of the class with the given name, possibly initializing it with our conf
+    def instantiateClass[T](className: String): T = {
+      val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
       // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
       // SparkConf, then one taking no arguments
       try {
@@ -178,11 +176,17 @@ object SparkEnv extends Logging {
       }
     }
 
-    val serializer = instantiateClass[Serializer](
+    // Create an instance of the class named by the given SparkConf property, or defaultClassName
+    // if the property is not set, possibly initializing it with our conf
+    def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
+      instantiateClass[T](conf.get(propertyName, defaultClassName))
+    }
+
+    val serializer = instantiateClassFromConf[Serializer](
       "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
     logDebug(s"Using serializer: ${serializer.getClass}")
 
-    val closureSerializer = instantiateClass[Serializer](
+    val closureSerializer = instantiateClassFromConf[Serializer](
       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
     def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
@@ -246,8 +250,13 @@ object SparkEnv extends Logging {
       "."
     }
 
-    val shuffleManager = instantiateClass[ShuffleManager](
-      "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+    // Let the user specify short names for shuffle managers
+    val shortShuffleMgrNames = Map(
+      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
+    val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
+    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
 
     val shuffleMemoryManager = new ShuffleMemoryManager(conf)
 
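With this change, spark.shuffle.manager accepts either a short name ("hash", "sort") or a fully qualified class name. A minimal standalone sketch of the resolution logic, mirroring the map and getOrElse lookup added above (the custom class name in the comment is hypothetical):

// Resolve a spark.shuffle.manager value to a fully qualified class name.
// Unknown values fall through unchanged and are treated as class names.
val shortShuffleMgrNames = Map(
  "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
  "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")

def resolveShuffleManagerClass(configured: String): String =
  shortShuffleMgrNames.getOrElse(configured.toLowerCase, configured)

// resolveShuffleManagerClass("sort")
//   => "org.apache.spark.shuffle.sort.SortShuffleManager"
// resolveShuffleManagerClass("com.example.MyShuffleManager")  // hypothetical custom manager
//   => "com.example.MyShuffleManager"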

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 25 additions & 6 deletions
@@ -17,14 +17,15 @@
 
 package org.apache.spark.broadcast
 
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream,
+  ObjectInputStream, ObjectOutputStream, OutputStream}
 
 import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
 
 /**
  * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
@@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging {
   private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
   private var initialized = false
   private var conf: SparkConf = null
+  private var compress: Boolean = false
+  private var compressionCodec: CompressionCodec = null
 
   def initialize(_isDriver: Boolean, conf: SparkConf) {
     TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
     synchronized {
       if (!initialized) {
+        compress = conf.getBoolean("spark.broadcast.compress", true)
+        compressionCodec = CompressionCodec.createCodec(conf)
         initialized = true
       }
     }
@@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T](obj: T): TorrentInfo = {
-    val byteArray = Utils.serialize[T](obj)
+  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+    val bos = new ByteArrayOutputStream()
+    val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serOut = ser.serializeStream(out)
+    serOut.writeObject[T](obj).close()
+    val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
 
     var blockNum = byteArray.length / BLOCK_SIZE
@@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging {
     info
   }
 
-  def unBlockifyObject[T](
+  def unBlockifyObject[T: ClassTag](
       arrayOfBlocks: Array[TorrentBlock],
       totalBytes: Int,
       totalBlocks: Int): T = {
@@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging {
       System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
         i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
     }
-    Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+
+    val in: InputStream = {
+      val arrIn = new ByteArrayInputStream(retByteArray)
+      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serIn = ser.deserializeStream(in)
+    val obj = serIn.readObject[T]()
+    serIn.close()
+    obj
   }
 
   /**
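blockifyObject and unBlockifyObject now go through the configured Spark serializer and, when spark.broadcast.compress is true, wrap the byte streams in the configured compression codec. A rough standalone sketch of that wrap-the-stream pattern, with GZIP and Java serialization standing in for Spark's CompressionCodec and SparkEnv serializer (illustrative only, not the committed code):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Serialize through an optionally compressed stream, then read the bytes back the same way.
def roundTrip[T](obj: T, compress: Boolean): T = {
  val bos = new ByteArrayOutputStream()
  val out = if (compress) new GZIPOutputStream(bos) else bos
  val serOut = new ObjectOutputStream(out)
  serOut.writeObject(obj)
  serOut.close()                 // flushes and finishes the (possibly compressed) stream

  val bytes = bos.toByteArray    // in TorrentBroadcast these bytes are then cut into blocks
  val bin = new ByteArrayInputStream(bytes)
  val in = if (compress) new GZIPInputStream(bin) else bin
  val serIn = new ObjectInputStream(in)
  val result = serIn.readObject().asInstanceOf[T]
  serIn.close()
  result
}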

core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala

Lines changed: 0 additions & 1 deletion
@@ -228,7 +228,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
 
   /** Fill in values by parsing user options. */
   private def parseOpts(opts: Seq[String]): Unit = {
-    var inSparkOpts = true
     val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r
 
     // Delineates parsing of Spark options from parsing of user options.

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ private[spark] class Executor(
     for (taskRunner <- runningTasks.values()) {
       if (!taskRunner.attemptedTask.isEmpty) {
         Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
+          metrics.updateShuffleReadMetrics
           tasksMetrics += ((taskRunner.taskId, metrics))
         }
       }

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 42 additions & 13 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.executor
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.{BlockId, BlockStatus}
 
@@ -81,12 +83,27 @@ class TaskMetrics extends Serializable {
   var inputMetrics: Option[InputMetrics] = None
 
   /**
-   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
+   * This includes read metrics aggregated over all the task's shuffle dependencies.
    */
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
   def shuffleReadMetrics = _shuffleReadMetrics
 
+  /**
+   * This should only be used when recreating TaskMetrics, not when updating read metrics in
+   * executors.
+   */
+  private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
+    _shuffleReadMetrics = shuffleReadMetrics
+  }
+
+  /**
+   * ShuffleReadMetrics per dependency for collecting independently while task is in progress.
+   */
+  @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
+    new ArrayBuffer[ShuffleReadMetrics]()
+
   /**
    * If this task writes to shuffle output, metrics on the written shuffle data will be collected
    * here
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
    */
   var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
 
-  /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */
-  def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
-    _shuffleReadMetrics match {
-      case Some(existingMetrics) =>
-        existingMetrics.shuffleFinishTime = math.max(
-          existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
-        existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
-        existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
-        existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
-        existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
-      case None =>
-        _shuffleReadMetrics = Some(newMetrics)
+  /**
+   * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
+   * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
+   * dependency, and merge these metrics before reporting them to the driver. This method returns
+   * a ShuffleReadMetrics for a dependency and registers it for merging later.
+   */
+  private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
+    val readMetrics = new ShuffleReadMetrics()
+    depsShuffleReadMetrics += readMetrics
+    readMetrics
+  }
+
+  /**
+   * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
+   */
+  private[spark] def updateShuffleReadMetrics() = synchronized {
+    val merged = new ShuffleReadMetrics()
+    for (depMetrics <- depsShuffleReadMetrics) {
+      merged.fetchWaitTime += depMetrics.fetchWaitTime
+      merged.localBlocksFetched += depMetrics.localBlocksFetched
+      merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
+      merged.remoteBytesRead += depMetrics.remoteBytesRead
+      merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
     }
+    _shuffleReadMetrics = Some(merged)
   }
 }
 
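The pattern here: while a task runs, each shuffle dependency accumulates into its own ShuffleReadMetrics (handed out by createShuffleReadMetricsForDependency), and updateShuffleReadMetrics folds them into one result under a lock before the metrics are reported. A toy sketch of the same pattern with hypothetical class and field names, not Spark's actual API:

import scala.collection.mutable.ArrayBuffer

// Stand-in for ShuffleReadMetrics: one accumulator per shuffle dependency.
class DepReadMetrics {
  var remoteBytesRead: Long = 0L
  var localBlocksFetched: Int = 0
  var remoteBlocksFetched: Int = 0
  var fetchWaitTime: Long = 0L
}

class TaskReadMetrics {
  private val perDependency = new ArrayBuffer[DepReadMetrics]()
  @volatile var merged: Option[DepReadMetrics] = None

  // Each reader gets its own accumulator, so readers in different threads never share one.
  def createForDependency(): DepReadMetrics = synchronized {
    val m = new DepReadMetrics()
    perDependency += m
    m
  }

  // Fold every per-dependency accumulator into a single result before reporting.
  def mergeAll(): Unit = synchronized {
    val total = new DepReadMetrics()
    for (m <- perDependency) {
      total.remoteBytesRead += m.remoteBytesRead
      total.localBlocksFetched += m.localBlocksFetched
      total.remoteBlocksFetched += m.remoteBlocksFetched
      total.fetchWaitTime += m.fetchWaitTime
    }
    merged = Some(total)
  }
}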

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 5 additions & 0 deletions
@@ -1233,6 +1233,11 @@ abstract class RDD[T: ClassTag](
     dependencies.head.rdd.asInstanceOf[RDD[U]]
   }
 
+  /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */
+  protected[spark] def parent[U: ClassTag](j: Int) = {
+    dependencies(j).rdd.asInstanceOf[RDD[U]]
+  }
+
   /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
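The new parent accessor is a typed index into dependencies, generalizing firstParent to the jth dependency. It is protected[spark], so user code cannot call it directly; this hypothetical helper shows what it is equivalent to in terms of the public dependencies method:

import org.apache.spark.rdd.RDD

// Equivalent of rdd.parent[U](j) for illustration: the jth dependency's RDD, cast to RDD[U].
def jthParent[U](rdd: RDD[_], j: Int): RDD[U] =
  rdd.dependencies(j).rdd.asInstanceOf[RDD[U]]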

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 4 additions & 9 deletions
@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
+      serializer: Serializer,
+      shuffleMetrics: ShuffleReadMetrics)
     : Iterator[T] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       }
     }
 
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      val shuffleMetrics = new ShuffleReadMetrics
-      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
-      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
-      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
-      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
-      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
+      context.taskMetrics.updateShuffleReadMetrics()
     })
 
     new InterruptibleIterator[T](context, completionIter)
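The per-block metric bookkeeping moves into the fetcher itself (via the shuffleMetrics argument), and the completion callback now only triggers the merge in TaskMetrics. The CompletionIterator used above is a Spark-internal utility; a minimal sketch of the general wrapper pattern, assuming nothing beyond plain Scala:

// Wrap an iterator so a side-effecting callback runs exactly once, when the
// underlying iterator reports that it has no more elements.
def withCompletion[A](underlying: Iterator[A])(onComplete: => Unit): Iterator[A] =
  new Iterator[A] {
    private var completed = false
    override def hasNext: Boolean = {
      val more = underlying.hasNext
      if (!more && !completed) {
        completed = true
        onComplete               // here: context.taskMetrics.updateShuffleReadMetrics()
      }
      more
    }
    override def next(): A = underlying.next()
  }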

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 4 additions & 2 deletions
@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
+      readMetrics)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
@@ -58,7 +60,7 @@ private[spark] class HashShuffleReader[K, C](
     // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
     // the ExternalSorter won't spill to disk.
     val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
-    sorter.write(aggregatedIter)
+    sorter.insertAll(aggregatedIter)
     context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
     context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
     sorter.iterator
