From 5bd4b7b3c922b580317c3df466e454b3dc3d66f7 Mon Sep 17 00:00:00 2001 From: sandeep katta Date: Mon, 27 Jan 2020 12:07:14 +0530 Subject: [PATCH 1/3] Update inputmetrics in DataSourceRDD --- .../datasources/v2/DataSourceRDD.scala | 96 ++++++++++++++----- .../DataSourceScanExecRedactionSuite.scala | 30 ++++++ 2 files changed, 104 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 71560f60f531c..234718bec12ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.language.existentials + import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable @@ -47,31 +52,16 @@ class DataSourceRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val inputPartition = castPartition(split).inputPartition - val reader: PartitionReader[_] = if (columnarReads) { - partitionReaderFactory.createColumnarReader(inputPartition) + val (iter, reader) = if (columnarReads) { + val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) + val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader)) + (iter, batchReader) } else { - partitionReaderFactory.createReader(inputPartition) + val rowReader = partitionReaderFactory.createReader(inputPartition) + val iter = new MetricsIterator(new PartitionIterator[InternalRow](rowReader)) + (iter, rowReader) } - context.addTaskCompletionListener[Unit](_ => reader.close()) - val iter = new Iterator[Any] { - private[this] var valuePrepared = false - - override def hasNext: Boolean = { - if (!valuePrepared) { - valuePrepared = reader.next() - } - valuePrepared - } - - override def next(): Any = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - valuePrepared = false - reader.get() - } - } // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } @@ -80,3 +70,65 @@ class DataSourceRDD( castPartition(split).inputPartition.preferredLocations() } } + +class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] { + private[this] var valuePrepared = false + + override def hasNext: Boolean = { + if (!valuePrepared) { + valuePrepared = reader.next() + } + valuePrepared + } + + override def next(): T = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + valuePrepared = false + reader.get() + } +} + +class MetricsHandler extends Logging with Serializable { + private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics + private val startingBytesRead = inputMetrics.bytesRead + private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + def updateMetrics(numRows: Int, force: Boolean = false): 
Unit = { + inputMetrics.incRecordsRead(numRows) + val shouldUpdateBytesRead = + inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0 + if (shouldUpdateBytesRead || force) { + inputMetrics.setBytesRead(startingBytesRead + getBytesRead()) + } + } +} + +private class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { + protected val metricsHandler = new MetricsHandler + + override def hasNext: Boolean = { + if (iter.hasNext) { + true + } else { + metricsHandler.updateMetrics(0, force = true) + false + } + } + + override def next(): I = { + val item = iter.next + metricsHandler.updateMetrics(1) + item + } +} + +private class MetricsBatchIterator( + iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) { + override def next(): ColumnarBatch = { + val batch: ColumnarBatch = iter.next + metricsHandler.updateMetrics(batch.numRows) + batch + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 073aed8206ed7..9bd2af27d0a5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -16,9 +16,12 @@ */ package org.apache.spark.sql.execution +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan @@ -167,4 +170,31 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { } } } + + test("SPARK-30362: test input metrics for DSV2") { + Seq("json", "orc", "parquet").foreach { format => + withTempPath { path => + val dir = path.getCanonicalPath + spark.range(0, 10).write.format(format).save(dir) + val df = spark.read.format(format).load(dir) + val bytesReads = new mutable.ArrayBuffer[Long]() + val recordsRead = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(bytesReads.sum > 0) + assert(recordsRead.sum == 10) + } finally { + sparkContext.removeSparkListener(bytesReadListener) + } + } + } + } } From 532fd1ffbddcea866c1ff03724862126cfc7bea0 Mon Sep 17 00:00:00 2001 From: sandeep katta Date: Tue, 28 Jan 2020 10:07:38 +0530 Subject: [PATCH 2/3] fixed review comments --- .../spark/sql/execution/datasources/v2/DataSourceRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 234718bec12ae..e0438ead7bbbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -71,7 +71,7 @@ class DataSourceRDD( } } -class 
PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] { +private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -90,7 +90,7 @@ class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] { } } -class MetricsHandler extends Logging with Serializable { +private class MetricsHandler extends Logging with Serializable { private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics private val startingBytesRead = inputMetrics.bytesRead private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() From 34425ad1eae32fe1ca177ad359685c410a467760 Mon Sep 17 00:00:00 2001 From: sandeep katta Date: Thu, 30 Jan 2020 19:28:36 +0530 Subject: [PATCH 3/3] fixed review comments --- .../datasources/v2/DataSourceRDD.scala | 9 ++-- .../DataSourceScanExecRedactionSuite.scala | 42 ++++++++++--------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index e0438ead7bbbc..63403b9577237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -58,7 +58,7 @@ class DataSourceRDD( (iter, batchReader) } else { val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsIterator(new PartitionIterator[InternalRow](rowReader)) + val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader)) (iter, rowReader) } context.addTaskCompletionListener[Unit](_ => reader.close()) @@ -105,7 +105,7 @@ private class MetricsHandler extends Logging with Serializable { } } -private class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { +private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { protected val metricsHandler = new MetricsHandler override def hasNext: Boolean = { @@ -116,8 +116,11 @@ private class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { false } } +} - override def next(): I = { +private class MetricsRowIterator( + iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) { + override def next(): InternalRow = { val item = iter.next metricsHandler.updateMetrics(1) item diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 9bd2af27d0a5a..f1411b263c77b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -172,27 +172,29 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { } test("SPARK-30362: test input metrics for DSV2") { - Seq("json", "orc", "parquet").foreach { format => - withTempPath { path => - val dir = path.getCanonicalPath - spark.range(0, 10).write.format(format).save(dir) - val df = spark.read.format(format).load(dir) - val bytesReads = new mutable.ArrayBuffer[Long]() - val recordsRead = new mutable.ArrayBuffer[Long]() - val bytesReadListener = new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - bytesReads += 
taskEnd.taskMetrics.inputMetrics.bytesRead - recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + Seq("json", "orc", "parquet").foreach { format => + withTempPath { path => + val dir = path.getCanonicalPath + spark.range(0, 10).write.format(format).save(dir) + val df = spark.read.format(format).load(dir) + val bytesReads = new mutable.ArrayBuffer[Long]() + val recordsRead = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(bytesReads.sum > 0) + assert(recordsRead.sum == 10) + } finally { + sparkContext.removeSparkListener(bytesReadListener) } - } - sparkContext.addSparkListener(bytesReadListener) - try { - df.collect() - sparkContext.listenerBus.waitUntilEmpty() - assert(bytesReads.sum > 0) - assert(recordsRead.sum == 10) - } finally { - sparkContext.removeSparkListener(bytesReadListener) } } }
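
A minimal standalone sketch for trying the new DSV2 input metrics outside the test suite. It mirrors the listener wiring in the test above; everything else is an assumption for illustration: the object name Dsv2InputMetricsDemo, the local[2] master, the /tmp output path, and the use of Thread.sleep as a crude stand-in for listenerBus.waitUntilEmpty() (which is private[spark] and not callable from user code). The config key spark.sql.sources.useV1SourceList corresponds to SQLConf.USE_V1_SOURCE_LIST used in the test.

import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

object Dsv2InputMetricsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("dsv2-input-metrics-demo")
      // An empty useV1SourceList routes built-in formats through the DSV2 scan path,
      // i.e. through DataSourceRDD and the MetricsIterator classes added by this patch.
      .config("spark.sql.sources.useV1SourceList", "")
      .getOrCreate()

    val bytesRead = new mutable.ArrayBuffer[Long]()
    val recordsRead = new mutable.ArrayBuffer[Long]()
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        // Same metrics the test asserts on, collected per finished task.
        bytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
        recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
      }
    })

    // Hypothetical scratch location; any writable path works.
    val dir = "/tmp/dsv2-input-metrics-demo"
    spark.range(0, 10).write.mode("overwrite").parquet(dir)
    spark.read.parquet(dir).collect()

    // Task-end events are delivered asynchronously; sleep briefly so the
    // listener has seen them before we read the buffers (sketch only).
    Thread.sleep(2000)
    println(s"bytesRead=${bytesRead.sum}, recordsRead=${recordsRead.sum}")

    spark.stop()
  }
}

Without this change, a DSV2 scan leaves the task input metrics untouched, so the printed recordsRead and bytesRead stay at zero; with the MetricsIterator/MetricsHandler wiring they report the rows and bytes actually read, which is the gap SPARK-30362 closes.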