
Commit 5854f77

jerryshao authored and cloud-fan committed
[SPARK-20244][CORE] Handle incorrect bytesRead metrics when using PySpark
## What changes were proposed in this pull request?

Hadoop FileSystem's statistics are based on thread-local variables, which is fine as long as the whole RDD computation chain runs in a single thread. But if a child RDD spawns another thread to consume the iterator obtained from a Hadoop RDD, the bytesRead computation becomes incorrect, because the iterator's `next()` and `close()` may then run in different threads. This can happen when using PySpark with PythonRDD. So here we build a map to track `bytesRead` per thread and sum the entries together. This method would be used in three RDDs: `HadoopRDD`, `NewHadoopRDD` and `FileScanRDD`. I assume `FileScanRDD` cannot be called directly, so I only fixed `HadoopRDD` and `NewHadoopRDD`.

## How was this patch tested?

Unit tests and local cluster verification.

Author: jerryshao <[email protected]>

Closes #17617 from jerryshao/SPARK-20244.
1 parent 24db358 commit 5854f77
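
To make the idea concrete, here is a minimal, self-contained Scala sketch of the per-thread aggregation the commit message describes. It is an illustration, not the patched code: ThreadLocalStats is a hypothetical stand-in for Hadoop's thread-local FileSystem.Statistics, and the class name PerThreadBytesRead is invented for this sketch.

import scala.collection.mutable

// Hypothetical stand-in for Hadoop's thread-local statistics: each thread
// only ever sees the bytes it read itself.
object ThreadLocalStats {
  private val local = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }
  def add(n: Long): Unit = local.set(local.get + n)
  def bytesReadInThisThread: Long = local.get
}

// Core idea of the fix: remember the latest reading per thread id, subtract
// the baseline captured on the creating thread, and sum across all threads.
class PerThreadBytesRead {
  private val baseline = (Thread.currentThread().getId, ThreadLocalStats.bytesReadInThisThread)
  private val bytesReadMap = mutable.HashMap.empty[Long, Long]

  def apply(): Long = bytesReadMap.synchronized {
    bytesReadMap.put(Thread.currentThread().getId, ThreadLocalStats.bytesReadInThisThread)
    bytesReadMap.map { case (tid, bytes) =>
      bytes - (if (tid == baseline._1) baseline._2 else 0L)
    }.sum
  }
}

Summing a map keyed by thread id means a reading taken on a consumer thread is never lost when a later call happens on the task thread; the old implementation subtracted a single baseline from whichever thread happened to call the callback, dropping bytes read on other threads.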

File tree

4 files changed (+66, -9 lines)

core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala

Lines changed: 22 additions & 6 deletions

@@ -23,6 +23,7 @@ import java.text.DateFormat
 import java.util.{Arrays, Comparator, Date, Locale}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging {
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
    * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
    * return the bytes read on r since t.
-   *
-   * @return None if the required method can't be found.
    */
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
-    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
-    val f = () => threadStats.map(_.getBytesRead).sum
-    val baselineBytesRead = f()
-    () => f() - baselineBytesRead
+    val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+    val baseline = (Thread.currentThread().getId, f())
+
+    /**
+     * This function may be called in both spawned child threads and parent task thread (in
+     * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics.
+     * So we need a map to track the bytes read from the child threads and parent thread,
+     * summing them together to get the bytes read of this task.
+     */
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
+    }
   }
 
   /**
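
As a usage illustration (an assumed scenario, not code from this commit): the callback created on the task thread must also be invoked from any child thread that consumes the Hadoop iterator, so that the child's thread-local reading lands in bytesReadMap before the final sum is taken.

// Illustrative sketch only. The callback comes from the method above; the
// threading mirrors what PythonRDD does, with the consumer logic elided.
val getBytesRead: () => Long = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()

val consumer = new Thread(new Runnable {
  override def run(): Unit = {
    // ... consume the Hadoop record iterator here ...
    getBytesRead() // records this thread's thread-local count in the map
  }
})
consumer.start()
consumer.join()

// Invoked later on the task thread (e.g. from a completion listener), the
// callback returns the task thread's delta plus the consumer's reading.
val totalBytesRead = getBytesRead()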

core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala

Lines changed: 7 additions & 1 deletion

@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
         null
       }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener{ context => closeIfNeeded() }
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing is to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        closeIfNeeded()
+      }
+
       private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
       private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
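
For context, a hedged reconstruction of the private helper the new listener invokes; the actual HadoopRDD internals may differ in detail, but the shape is: read the callback's current cross-thread sum and write it into the task's input metrics.

// Assumed shape of HadoopRDD's updateBytesRead (a reconstruction, not part
// of this diff): getBytesReadCallback is the Option-wrapped callback from
// SparkHadoopUtil, inputMetrics the task-level input metrics.
private def updateBytesRead(): Unit = {
  getBytesReadCallback.foreach { getBytesRead =>
    inputMetrics.setBytesRead(getBytesRead())
  }
}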

core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala

Lines changed: 7 additions & 1 deletion

@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
       }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener(context => close())
+      context.addTaskCompletionListener { context =>
+        // Update the bytesRead before closing is to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        close()
+      }
+
       private var havePair = false
       private var recordsSinceMetricsUpdate = 0

core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala

Lines changed: 30 additions & 1 deletion

@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
   with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     }
     assert(bytesRead >= tmpFile.length())
   }
+
+  test("input metrics with old Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics with new Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
 }
 
 /**
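
The tests force the cross-thread pattern with ThreadUtils.runInNewThread, which runs a block on a freshly started thread and waits for it to finish (the real Spark helper also rethrows any exception from the child thread on the caller). A minimal stand-in using only JDK threading, under that assumption, looks like:

// Minimal stand-in for ThreadUtils.runInNewThread as the tests use it:
// run `body` on a new named thread and block until it completes, so the
// iterator is consumed off the task thread, as PythonRDD would do.
def runInNewThread[T](name: String, isDaemon: Boolean)(body: => T): Unit = {
  val thread = new Thread(new Runnable {
    override def run(): Unit = body
  }, name)
  thread.setDaemon(isDaemon)
  thread.start()
  thread.join()
}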
