@@ -24,6 +24,7 @@ import java.util.{Arrays, Comparator, Date, Locale}
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -148,13 +149,25 @@ class SparkHadoopUtil extends Logging {
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
     val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
     val baseline = (Thread.currentThread().getId, f())
-    val bytesReadMap = new ConcurrentHashMap[Long, Long]()
 
-    () => {
-      bytesReadMap.put(Thread.currentThread().getId, f())
-      bytesReadMap.asScala.map { case (k, v) =>
-        v - (if (k == baseline._1) baseline._2 else 0)
-      }.sum
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      /**
+       * Returns the Hadoop FileSystem bytes read by this task. This function may be called
+       * from both spawned child threads and the parent task thread (in PythonRDD), and Hadoop
+       * FileSystem uses thread-local variables to track its statistics, so we need a map to
+       * track the bytes read from the child threads and the parent thread, summing them
+       * together to get the bytes read of this task.
+       */
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
     }
   }
 
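
A minimal, self-contained sketch of the pattern the added code uses: a callback that aggregates a per-thread counter across the thread that created it and any child threads that later invoke it. This is not Spark code; ThreadAwareCounterExample, record, and makeCallback are hypothetical names, and a ThreadLocal stands in for Hadoop FileSystem's per-thread statistics.

import scala.collection.mutable

object ThreadAwareCounterExample {
  // Stand-in for Hadoop's thread-local FileSystem statistics.
  private val perThreadCount = new ThreadLocal[Long] { override def initialValue(): Long = 0L }

  def record(n: Long): Unit = perThreadCount.set(perThreadCount.get + n)

  // Returns a callback that sums the counts seen by every thread that has invoked it,
  // subtracting the creating thread's baseline so only post-creation activity is counted.
  def makeCallback(): () => Long = {
    val baseline = (Thread.currentThread().getId, perThreadCount.get)
    new Function0[Long] {
      private val seen = new mutable.HashMap[Long, Long]()
      override def apply(): Long = seen.synchronized {
        seen.put(Thread.currentThread().getId, perThreadCount.get)
        seen.map { case (id, v) => v - (if (id == baseline._1) baseline._2 else 0L) }.sum
      }
    }
  }

  def main(args: Array[String]): Unit = {
    record(10L)                 // counted before the callback exists; excluded via the baseline
    val callback = makeCallback()
    record(5L)                  // parent thread, after the baseline
    val child = new Thread(new Runnable {
      override def run(): Unit = { record(7L); callback() }  // child registers itself in the map
    })
    child.start()
    child.join()
    println(callback())         // 12 = 5 (parent since baseline) + 7 (child)
  }
}

The synchronized mutable.HashMap mirrors the change above: each calling thread records its own thread-local total under its thread id, and apply() sums the entries, subtracting the baseline captured on the thread that created the callback.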