Commit f1a615f

Author: Kostas Sakellis (committed)
[SPARK-4092] [CORE] Fix InputMetrics for coalesce'd Rdds
When calculating the input metrics there was an assumption that one task only reads from one block; this is not true for some operations, including coalesce. This patch simply increments the task's input metrics if previous ones of the same read method existed. A limitation of this patch is that if a task reads from two blocks with different read methods, one will override the other.
1 parent a61eaed commit f1a615f
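
The change boils down to a reuse-or-create pattern on the task's metrics object. Below is a minimal, self-contained sketch of that pattern distilled from the diffs that follow; ReadMethod and InputMetricsStub are hypothetical stand-ins for Spark's internal DataReadMethod and InputMetrics, not the real API.

// Hypothetical stand-ins for Spark's DataReadMethod and InputMetrics, for illustration only.
object ReadMethod extends Enumeration {
  val Hadoop, Memory = Value
}

case class InputMetricsStub(readMethod: ReadMethod.Value, var bytesRead: Long = 0L)

// Reuse the task's existing metrics only if they were produced by the same read method;
// otherwise start a fresh object. The filter is also the source of the limitation noted
// above: metrics from a different read method are not merged, so the later one wins.
def metricsFor(existing: Option[InputMetricsStub],
               readMethod: ReadMethod.Value): InputMetricsStub = {
  existing
    .filter(_.readMethod == readMethod)
    .getOrElse(InputMetricsStub(readMethod))
}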

File tree: 4 files changed, +102 -38 lines

core/src/main/scala/org/apache/spark/CacheManager.scala
Lines changed: 7 additions & 0 deletions

@@ -44,7 +44,14 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
     blockManager.get(key) match {
       case Some(blockResult) =>
         // Partition is already materialized, so just return its values
+        val existingMetrics = context.taskMetrics.inputMetrics
+        val prevBytesRead = existingMetrics
+          .filter(_.readMethod == blockResult.inputMetrics.readMethod)
+          .map(_.bytesRead)
+          .getOrElse(0L)
+
         context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics)
+        context.taskMetrics.inputMetrics.get.bytesRead += prevBytesRead
         new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
 
       case None =>
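
The CacheManager branch can be read as "carry the previous total forward when the cached block was produced by the same read method". A sketch of that bookkeeping, reusing the hypothetical InputMetricsStub from the sketch above:

// Sketch only: when a cached block's metrics replace the task's current metrics, any bytes
// already recorded for the same read method are added back, mirroring prevBytesRead in the diff.
def recordCachedBlock(taskMetrics: Option[InputMetricsStub],
                      blockMetrics: InputMetricsStub): InputMetricsStub = {
  val prevBytesRead = taskMetrics
    .filter(_.readMethod == blockMetrics.readMethod)
    .map(_.bytesRead)
    .getOrElse(0L)
  blockMetrics.bytesRead += prevBytesRead
  blockMetrics
}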

core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
Lines changed: 10 additions & 4 deletions

@@ -213,7 +213,11 @@ class HadoopRDD[K, V](
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
-      val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+      val readMethod = DataReadMethod.Hadoop
+      val inputMetrics = context.taskMetrics.inputMetrics
+        .filter(_.readMethod == readMethod)
+        .getOrElse(new InputMetrics(readMethod))
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
@@ -239,6 +243,8 @@
 
       var recordsSinceMetricsUpdate = 0
 
+      val bytesReadAtStart = inputMetrics.bytesRead
+
       override def getNext() = {
         try {
           finished = !reader.next(key, value)
@@ -252,7 +258,7 @@
             && bytesReadCallback.isDefined) {
           recordsSinceMetricsUpdate = 0
           val bytesReadFn = bytesReadCallback.get
-          inputMetrics.bytesRead = bytesReadFn()
+          inputMetrics.bytesRead = bytesReadFn() + bytesReadAtStart
         } else {
           recordsSinceMetricsUpdate += 1
         }
@@ -264,12 +270,12 @@
         reader.close()
         if (bytesReadCallback.isDefined) {
           val bytesReadFn = bytesReadCallback.get
-          inputMetrics.bytesRead = bytesReadFn()
+          inputMetrics.bytesRead = bytesReadFn() + bytesReadAtStart
         } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
           // If we can't get the bytes read from the FS stats, fall back to the split size,
           // which may be inaccurate.
           try {
-            inputMetrics.bytesRead = split.inputSplit.value.getLength
+            inputMetrics.bytesRead = split.inputSplit.value.getLength + bytesReadAtStart
             context.taskMetrics.inputMetrics = Some(inputMetrics)
           } catch {
             case e: java.io.IOException =>
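
For context, coalesce is what makes a single task read more than one Hadoop split, which is why the snapshot in bytesReadAtStart matters. A small usage sketch, assuming an active SparkContext (e.g. spark-shell); the path and partition counts are illustrative only:

// Illustrative only: each of the 2 coalesced tasks reads several of the original 8 splits,
// so HadoopRDD.compute() runs multiple times per task against the same InputMetrics object.
val rdd = sc.textFile("file:///tmp/input.txt", 8)  // hypothetical local path
rdd.coalesce(2).count()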

core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
Lines changed: 10 additions & 4 deletions

@@ -108,7 +108,11 @@ class NewHadoopRDD[K, V](
       logInfo("Input split: " + split.serializableHadoopSplit)
       val conf = confBroadcast.value.value
 
-      val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+      val readMethod = DataReadMethod.Hadoop
+      val inputMetrics = context.taskMetrics.inputMetrics
+        .filter(_.readMethod == readMethod)
+        .getOrElse(new InputMetrics(readMethod))
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
@@ -139,6 +143,8 @@
       var finished = false
       var recordsSinceMetricsUpdate = 0
 
+      val bytesReadAtStart = inputMetrics.bytesRead
+
       override def hasNext: Boolean = {
         if (!finished && !havePair) {
           finished = !reader.nextKeyValue
@@ -158,7 +164,7 @@
             && bytesReadCallback.isDefined) {
           recordsSinceMetricsUpdate = 0
           val bytesReadFn = bytesReadCallback.get
-          inputMetrics.bytesRead = bytesReadFn()
+          inputMetrics.bytesRead = bytesReadFn() + bytesReadAtStart
         } else {
           recordsSinceMetricsUpdate += 1
         }
@@ -173,12 +179,12 @@
         // Update metrics with final amount
         if (bytesReadCallback.isDefined) {
           val bytesReadFn = bytesReadCallback.get
-          inputMetrics.bytesRead = bytesReadFn()
+          inputMetrics.bytesRead = bytesReadFn() + bytesReadAtStart
         } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
           // If we can't get the bytes read from the FS stats, fall back to the split size,
           // which may be inaccurate.
           try {
-            inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength
+            inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength + bytesReadAtStart
             context.taskMetrics.inputMetrics = Some(inputMetrics)
           } catch {
             case e: java.io.IOException =>
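
Both Hadoop code paths rebase the updated byte count on a snapshot taken before the split is read, so earlier splits' bytes are never overwritten. A tiny arithmetic sketch with made-up numbers:

// Made-up numbers, for illustration only.
var bytesRead = 700L              // inputMetrics.bytesRead carried over from an earlier split
val bytesReadAtStart = bytesRead  // snapshot taken before reading this split

def bytesReadFn(): Long = 450L    // stand-in for the FileSystem-statistics callback

// Periodic updates and the final update in close() add the snapshot back:
bytesRead = bytesReadFn() + bytesReadAtStart
assert(bytesRead == 1150L)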

core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
Lines changed: 75 additions & 30 deletions

@@ -17,8 +17,11 @@
 
 package org.apache.spark.metrics
 
+import org.apache.hadoop.io.{Text, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.FunSuite
 
+import org.apache.spark.util.Utils
 import org.apache.spark.SharedSparkContext
 import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener}
 
@@ -27,50 +30,92 @@ import scala.collection.mutable.ArrayBuffer
 import java.io.{FileWriter, PrintWriter, File}
 
 class InputMetricsSuite extends FunSuite with SharedSparkContext {
-  test("input metrics when reading text file with single split") {
-    val file = new File(getClass.getSimpleName + ".txt")
-    val pw = new PrintWriter(new FileWriter(file))
-    pw.println("some stuff")
-    pw.println("some other stuff")
-    pw.println("yet more stuff")
-    pw.println("too much stuff")
+
+  @transient var tmpDir: File = _
+  @transient var tmpFile: File = _
+  @transient var tmpFilePath: String = _
+
+  override def beforeAll() {
+    super.beforeAll()
+
+    tmpDir = Utils.createTempDir()
+    val testTempDir = new File(tmpDir, "test")
+    testTempDir.mkdir()
+
+    tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt")
+    val pw = new PrintWriter(new FileWriter(tmpFile))
+    for (x <- 1 to 1000000) {
+      pw.println("s")
+    }
     pw.close()
-    file.deleteOnExit()
 
-    val taskBytesRead = new ArrayBuffer[Long]()
-    sc.addSparkListener(new SparkListener() {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
-      }
-    })
-    sc.textFile("file://" + file.getAbsolutePath, 2).count()
+    // Path to tmpFile
+    tmpFilePath = "file://" + tmpFile.getAbsolutePath
+  }
 
-    // Wait for task end events to come in
-    sc.listenerBus.waitUntilEmpty(500)
-    assert(taskBytesRead.length == 2)
-    assert(taskBytesRead.sum >= file.length())
+  override def afterAll() {
+    super.afterAll()
+    Utils.deleteRecursively(tmpDir)
   }
 
-  test("input metrics when reading text file with multiple splits") {
-    val file = new File(getClass.getSimpleName + ".txt")
-    val pw = new PrintWriter(new FileWriter(file))
-    for (i <- 0 until 10000) {
-      pw.println("some stuff")
+  test("input metrics for old hadoop with coalesce") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).count()
     }
-    pw.close()
-    file.deleteOnExit()
+    val bytesRead2 = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).coalesce(2).count()
+    }
+    assert(bytesRead2 == bytesRead)
+    assert(bytesRead2 >= tmpFile.length())
+  }
+
+  test("input metrics with cache and coalesce") {
+    // prime the cache manager
+    val rdd = sc.textFile(tmpFilePath, 4).cache()
+    rdd.collect()
+
+    val bytesRead = runAndReturnBytesRead {
+      rdd.count()
+    }
+    val bytesRead2 = runAndReturnBytesRead {
+      rdd.coalesce(4).count()
+    }
+
+    // for count and coalesce, the same bytes should be read.
+    assert(bytesRead2 >= bytesRead)
+  }
 
+  test("input metrics for new Hadoop API with coalesce") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).count()
+    }
+    val bytesRead2 = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).coalesce(5).count()
+    }
+    assert(bytesRead2 == bytesRead)
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics when reading text file") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 2).count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  private def runAndReturnBytesRead(job: => Unit): Long = {
     val taskBytesRead = new ArrayBuffer[Long]()
     sc.addSparkListener(new SparkListener() {
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
         taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
       }
     })
-    sc.textFile("file://" + file.getAbsolutePath, 2).count()
 
-    // Wait for task end events to come in
+    job
+
     sc.listenerBus.waitUntilEmpty(500)
-    assert(taskBytesRead.length == 2)
-    assert(taskBytesRead.sum >= file.length())
+    taskBytesRead.sum
   }
 }
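
Outside the suite, the listener-based measurement behind runAndReturnBytesRead can be adapted to spot-check the metrics by hand. A sketch, assuming an active SparkContext; the path is illustrative, and note that sc.listenerBus.waitUntilEmpty (used in the test) is private[spark], so plain user code has to rely on the job having completed:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Mirrors the suite's runAndReturnBytesRead: record bytesRead from every finished task.
val taskBytesRead = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    // Use foreach since tasks without input report no InputMetrics.
    taskEnd.taskMetrics.inputMetrics.foreach(m => taskBytesRead += m.bytesRead)
  }
})
sc.textFile("file:///tmp/input.txt", 4).coalesce(2).count()  // hypothetical path
// The suite calls sc.listenerBus.waitUntilEmpty(500) here to flush pending events.
val totalBytesRead = taskBytesRead.sum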
