
Commit eb9a439

viirya authored and cloud-fan committed
[SPARK-34338][SQL] Report metrics from Datasource v2 scan
### What changes were proposed in this pull request?

This patch proposes to leverage the `CustomMetric` and `CustomTaskMetric` APIs to report custom metrics from a DS v2 scan to Spark.

### Why are the changes needed?

This is related to #31398. In SPARK-34297, we want to add a couple of metrics when reading from Kafka in SS. We need some public API changes in DS v2 to make that possible. This PR extracts only the DS v2 change and makes it general for DS v2, instead of tying it to the micro-batch DS v2 API.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test. Implemented a simple test DS v2 class locally and ran it:

```scala
scala> import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.execution.datasources.v2._

scala> classOf[CustomMetricDataSourceV2].getName
res0: String = org.apache.spark.sql.execution.datasources.v2.CustomMetricDataSourceV2

scala> val df = spark.read.format(res0).load()
df: org.apache.spark.sql.DataFrame = [i: int, j: int]

scala> df.collect
```

![Screen Shot 2021-03-30 at 11 07 13 PM](https://user-images.githubusercontent.com/68855/113098080-d8a49800-91ac-11eb-8681-be408a0f2e69.png)

Closes #31451 from viirya/dsv2-metrics.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3614448 · commit eb9a439
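The API surface this change builds on is small: a connector declares driver-side `CustomMetric`s from its `Scan` (via `supportedCustomMetrics()`) and reports per-task values from its `PartitionReader` through `currentMetricsValues()`; Spark matches the two by name and shows the aggregated result in the SQL UI. The sketch below only illustrates that contract. It is not the `CustomMetricDataSourceV2` test class mentioned above, the `bytesRead` metric names are made up, and it assumes the package layout in this commit, where `CustomMetric` and `CustomTaskMetric` live directly under `org.apache.spark.sql.connector`.

```scala
import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric}

// Driver-side declaration: one instance per metric the scan can report.
// The exec node turns each of these into a SQLMetric keyed by name()
// (see DataSourceV2ScanExecBase.customMetrics in the diff below).
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "total bytes read by this scan"
  // Called on the driver with the latest value reported by every task.
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Executor-side value: a cheap snapshot returned from
// PartitionReader.currentMetricsValues(); name() must match the
// CustomMetric declared above.
class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead"
  override def value(): Long = bytes
}

// In the connector's Scan implementation one would then advertise the metric:
//   override def supportedCustomMetrics(): Array[CustomMetric] =
//     Array(new BytesReadMetric)
```

The diffs that follow show the Spark-side plumbing: `DataSourceV2ScanExecBase` registers one `SQLMetric` per declared `CustomMetric`, and `DataSourceRDD` / `ContinuousDataSourceRDD` poll the readers and update those metrics as rows are consumed.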

File tree

13 files changed: +308 −16 lines


sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java

Lines changed: 2 additions & 1 deletion

```diff
@@ -51,7 +51,8 @@ public interface PartitionReader<T> extends Closeable {
   T get();
 
   /**
-   * Returns an array of custom task metrics. By default it returns empty array.
+   * Returns an array of custom task metrics. By default it returns empty array. Note that it is
+   * not recommended to put heavy logic in this method as it may affect reading performance.
    */
   default CustomTaskMetric[] currentMetricsValues() {
     CustomTaskMetric[] NO_METRICS = {};
```
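The note about heavy logic matters because Spark calls `currentMetricsValues()` on the row-consumption path (see the `PartitionIterator` change in `DataSourceRDD.scala` below). Here is a hedged sketch of one way to keep it cheap: do the counting as a side effect of `get()` and only snapshot the counter when asked. `RowCountingReader` and the `rowsRead` metric name are hypothetical, and the import assumes this commit's `org.apache.spark.sql.connector.CustomTaskMetric` location.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Wraps another reader; counting happens as rows are read, so
// currentMetricsValues() only packages an already-computed number.
class RowCountingReader(delegate: PartitionReader[InternalRow])
  extends PartitionReader[InternalRow] {

  private var rowsRead = 0L

  override def next(): Boolean = delegate.next()

  override def get(): InternalRow = {
    rowsRead += 1 // cheap bookkeeping on the hot path
    delegate.get()
  }

  override def close(): Unit = delegate.close()

  // Spark may call this repeatedly while the task runs, so keep it cheap.
  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    val snapshot = rowsRead
    Array(new CustomTaskMetric {
      override def name(): String = "rowsRead"
      override def value(): Long = snapshot
    })
  }
}
```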

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -45,7 +45,7 @@ case class BatchScanExec(
   override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
 
   override lazy val inputRDD: RDD[InternalRow] = {
-    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
+    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
   }
 
   override def doCanonicalize(): BatchScanExec = {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -58,6 +58,7 @@ case class ContinuousScanExec(
       sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
       partitions,
       schema,
-      readerFactory.asInstanceOf[ContinuousPartitionReaderFactory])
+      readerFactory.asInstanceOf[ContinuousPartitionReaderFactory],
+      customMetrics)
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala

Lines changed: 16 additions & 4 deletions

```diff
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -37,7 +38,8 @@ class DataSourceRDD(
     sc: SparkContext,
     @transient private val inputPartitions: Seq[InputPartition],
     partitionReaderFactory: PartitionReaderFactory,
-    columnarReads: Boolean)
+    columnarReads: Boolean,
+    customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
@@ -55,11 +57,13 @@ class DataSourceRDD(
     val inputPartition = castPartition(split).inputPartition
     val (iter, reader) = if (columnarReads) {
       val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
-      val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
+      val iter = new MetricsBatchIterator(
+        new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
       (iter, batchReader)
     } else {
       val rowReader = partitionReaderFactory.createReader(inputPartition)
-      val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader))
+      val iter = new MetricsRowIterator(
+        new PartitionIterator[InternalRow](rowReader, customMetrics))
       (iter, rowReader)
     }
     context.addTaskCompletionListener[Unit](_ => reader.close())
@@ -72,7 +76,9 @@ class DataSourceRDD(
   }
 }
 
-private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
+private class PartitionIterator[T](
+    reader: PartitionReader[T],
+    customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   private[this] var valuePrepared = false
 
   override def hasNext: Boolean = {
@@ -86,6 +92,12 @@ private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[
     if (!hasNext) {
       throw QueryExecutionErrors.endOfStreamError()
     }
+    reader.currentMetricsValues.foreach { metric =>
+      assert(customMetrics.contains(metric.name()),
+        s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " +
+          s"${metric.name()}")
+      customMetrics(metric.name()).set(metric.value())
+    }
     valuePrepared = false
     reader.get()
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala

Lines changed: 8 additions & 2 deletions

```diff
@@ -32,8 +32,14 @@ import org.apache.spark.util.Utils
 
 trait DataSourceV2ScanExecBase extends LeafExecNode {
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric =>
+    customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
+  }.toMap
+
+  override lazy val metrics = {
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++
+      customMetrics
+  }
 
   def scan: Scan
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -46,6 +46,6 @@ case class MicroBatchScanExec(
   override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory()
 
   override lazy val inputRDD: RDD[InternalRow] = {
-    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
+    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
   }
 }
```
sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala

Lines changed: 79 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.Locale

import org.apache.spark.sql.connector.CustomMetric

object CustomMetrics {
  private[spark] val V2_CUSTOM = "v2Custom"

  /**
   * Given a class name, builds and returns a metric type for a V2 custom metric class
   * `CustomMetric`.
   */
  def buildV2CustomMetricTypeName(customMetric: CustomMetric): String = {
    s"${V2_CUSTOM}_${customMetric.getClass.getCanonicalName}"
  }

  /**
   * Given a V2 custom metric type name, this method parses it and returns the corresponding
   * `CustomMetric` class name.
   */
  def parseV2CustomMetricType(metricType: String): Option[String] = {
    if (metricType.startsWith(s"${V2_CUSTOM}_")) {
      Some(metricType.drop(V2_CUSTOM.length + 1))
    } else {
      None
    }
  }
}

/**
 * Built-in `CustomMetric` that sums up metric values.
 */
class CustomSumMetric extends CustomMetric {
  override def name(): String = "CustomSumMetric"

  override def description(): String = "Sum up CustomMetric"

  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
    taskMetrics.sum.toString
  }
}

/**
 * Built-in `CustomMetric` that computes average of metric values.
 */
class CustomAvgMetric extends CustomMetric {
  override def name(): String = "CustomAvgMetric"

  override def description(): String = "Average CustomMetric"

  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
    val average = if (taskMetrics.isEmpty) {
      0.0
    } else {
      taskMetrics.sum.toDouble / taskMetrics.length
    }
    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
    numberFormat.format(average)
  }
}
```

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 10 additions & 0 deletions

```diff
@@ -24,6 +24,7 @@ import scala.concurrent.duration._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.sql.connector.CustomMetric
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
 import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
 
@@ -107,6 +108,15 @@ object SQLMetrics {
     acc
   }
 
+  /**
+   * Create a metric to report data source v2 custom metric.
+   */
+  def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = {
+    val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric))
+    acc.register(sc, name = Some(customMetric.name()), countFailedValues = false)
+    acc
+  }
+
   /**
    * Create a metric to report the size information (including total, min, med, max) like data size,
    * spill size, etc.
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala

Lines changed: 7 additions & 1 deletion

```diff
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
 
@@ -52,7 +53,8 @@ class ContinuousDataSourceRDD(
     epochPollIntervalMs: Long,
     private val inputPartitions: Seq[InputPartition],
     schema: StructType,
-    partitionReaderFactory: ContinuousPartitionReaderFactory)
+    partitionReaderFactory: ContinuousPartitionReaderFactory,
+    customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
@@ -88,8 +90,12 @@ class ContinuousDataSourceRDD(
       partition.queueReader
     }
 
+    val partitionReader = readerForPartition.getPartitionReader()
     new NextIterator[InternalRow] {
       override def getNext(): InternalRow = {
+        partitionReader.currentMetricsValues.foreach { metric =>
+          customMetrics(metric.name()).set(metric.value())
+        }
         readerForPartition.next() match {
           case null =>
             finished = true
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala

Lines changed: 3 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.connector.read.PartitionReader
 import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ThreadUtils
@@ -47,6 +48,8 @@ class ContinuousQueuedDataReader(
   // Important sequencing - we must get our starting point before the provider threads start running
   private var currentOffset: PartitionOffset = reader.getOffset
 
+  def getPartitionReader(): PartitionReader[InternalRow] = reader
+
   /**
    * The record types in the read buffer.
    */
```
