diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 71560f60f531c..63403b9577237 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,10 +17,15 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.language.existentials
+
 import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
   extends Partition with Serializable
@@ -47,31 +52,16 @@ class DataSourceRDD(
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     val inputPartition = castPartition(split).inputPartition
-    val reader: PartitionReader[_] = if (columnarReads) {
-      partitionReaderFactory.createColumnarReader(inputPartition)
+    val (iter, reader) = if (columnarReads) {
+      val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
+      val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
+      (iter, batchReader)
     } else {
-      partitionReaderFactory.createReader(inputPartition)
+      val rowReader = partitionReaderFactory.createReader(inputPartition)
+      val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader))
+      (iter, rowReader)
     }
-
     context.addTaskCompletionListener[Unit](_ => reader.close())
-    val iter = new Iterator[Any] {
-      private[this] var valuePrepared = false
-
-      override def hasNext: Boolean = {
-        if (!valuePrepared) {
-          valuePrepared = reader.next()
-        }
-        valuePrepared
-      }
-
-      override def next(): Any = {
-        if (!hasNext) {
-          throw new java.util.NoSuchElementException("End of stream")
-        }
-        valuePrepared = false
-        reader.get()
-      }
-    }
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
   }
@@ -80,3 +70,68 @@ class DataSourceRDD(
     castPartition(split).inputPartition.preferredLocations()
   }
 }
+
+private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
+  private[this] var valuePrepared = false
+
+  override def hasNext: Boolean = {
+    if (!valuePrepared) {
+      valuePrepared = reader.next()
+    }
+    valuePrepared
+  }
+
+  override def next(): T = {
+    if (!hasNext) {
+      throw new java.util.NoSuchElementException("End of stream")
+    }
+    valuePrepared = false
+    reader.get()
+  }
+}
+
+private class MetricsHandler extends Logging with Serializable {
+  private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
+  private val startingBytesRead = inputMetrics.bytesRead
+  private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+
+  def updateMetrics(numRows: Int, force: Boolean = false): Unit = {
+    inputMetrics.incRecordsRead(numRows)
+    val shouldUpdateBytesRead =
+      inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0
+    if (shouldUpdateBytesRead || force) {
+      inputMetrics.setBytesRead(startingBytesRead + getBytesRead())
+    }
+  }
+}
+
+private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+  protected val metricsHandler = new MetricsHandler
+
+  override def hasNext: Boolean = {
+    if (iter.hasNext) {
+      true
+    } else {
+      metricsHandler.updateMetrics(0, force = true)
+      false
+    }
+  }
+}
+
+private class MetricsRowIterator(
+    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+  override def next(): InternalRow = {
+    val item = iter.next
+    metricsHandler.updateMetrics(1)
+    item
+  }
+}
+
+private class MetricsBatchIterator(
+    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+  override def next(): ColumnarBatch = {
+    val batch: ColumnarBatch = iter.next
+    metricsHandler.updateMetrics(batch.numRows)
+    batch
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 073aed8206ed7..f1411b263c77b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -16,9 +16,12 @@
  */
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable
+
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -167,4 +170,33 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest {
       }
     }
   }
+
+  test("SPARK-30362: test input metrics for DSV2") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      Seq("json", "orc", "parquet").foreach { format =>
+        withTempPath { path =>
+          val dir = path.getCanonicalPath
+          spark.range(0, 10).write.format(format).save(dir)
+          val df = spark.read.format(format).load(dir)
+          val bytesReads = new mutable.ArrayBuffer[Long]()
+          val recordsRead = new mutable.ArrayBuffer[Long]()
+          val bytesReadListener = new SparkListener() {
+            override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+              bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
+              recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+            }
+          }
+          sparkContext.addSparkListener(bytesReadListener)
+          try {
+            df.collect()
+            sparkContext.listenerBus.waitUntilEmpty()
+            assert(bytesReads.sum > 0)
+            assert(recordsRead.sum == 10)
+          } finally {
+            sparkContext.removeSparkListener(bytesReadListener)
+          }
+        }
+      }
+    }
+  }
 }