diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index c91f2b4bf38f..d6cf070cf4c8
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -51,7 +51,8 @@ public interface PartitionReader<T> extends Closeable {
   T get();
 
   /**
-   * Returns an array of custom task metrics. By default it returns empty array.
+   * Returns an array of custom task metrics. By default it returns an empty array. Note that it
+   * is not recommended to put heavy logic in this method, as it may affect reading performance.
    */
   default CustomTaskMetric[] currentMetricsValues() {
     CustomTaskMetric[] NO_METRICS = {};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index c199df676ced..1987c9e63a64
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -45,7 +45,7 @@ case class BatchScanExec(
   override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
 
   override lazy val inputRDD: RDD[InternalRow] = {
-    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
+    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
   }
 
   override def doCanonicalize(): BatchScanExec = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala
index dc95d157e40f..fea89c581e18
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala
@@ -58,6 +58,7 @@ case class ContinuousScanExec(
       sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
       partitions,
       schema,
-      readerFactory.asInstanceOf[ContinuousPartitionReaderFactory])
+      readerFactory.asInstanceOf[ContinuousPartitionReaderFactory],
+      customMetrics)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 5724be7d0591..7850dfa39d16
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -37,7 +38,8 @@ class DataSourceRDD(
     sc: SparkContext,
     @transient private val inputPartitions: Seq[InputPartition],
     partitionReaderFactory: PartitionReaderFactory,
-    columnarReads: Boolean)
+    columnarReads: Boolean,
+    customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
@@ -55,11 +57,13 @@ class DataSourceRDD(
     val inputPartition = castPartition(split).inputPartition
     val (iter, reader) = if (columnarReads) {
       val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
-      val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
+      val iter = new MetricsBatchIterator(
+        new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
       (iter, batchReader)
     } else {
       val rowReader = partitionReaderFactory.createReader(inputPartition)
-      val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader))
+      val iter = new MetricsRowIterator(
+        new PartitionIterator[InternalRow](rowReader, customMetrics))
       (iter, rowReader)
     }
     context.addTaskCompletionListener[Unit](_ => reader.close())
@@ -72,7 +76,9 @@ class DataSourceRDD(
   }
 }
 
-private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
+private class PartitionIterator[T](
+    reader: PartitionReader[T],
+    customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   private[this] var valuePrepared = false
 
   override def hasNext: Boolean = {
@@ -86,6 +92,12 @@ private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
     if (!hasNext) {
       throw QueryExecutionErrors.endOfStreamError()
     }
+    reader.currentMetricsValues.foreach { metric =>
+      assert(customMetrics.contains(metric.name()),
+        s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " +
+          s"${metric.name()}")
+      customMetrics(metric.name()).set(metric.value())
+    }
     valuePrepared = false
     reader.get()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index 083c6bc7999b..1248f89b2bdd
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -32,8 +32,14 @@ import org.apache.spark.util.Utils
 
 trait DataSourceV2ScanExecBase extends LeafExecNode {
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric =>
+    customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
+  }.toMap
+
+  override lazy val metrics = {
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++
+      customMetrics
+  }
 
   def scan: Scan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
index bca28e3cacb6..1430a32c8e81
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
@@ -46,6 +46,6 @@ case class MicroBatchScanExec(
   override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory()
 
   override lazy val inputRDD: RDD[InternalRow] = {
-    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
+    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
   }
 }
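
The wiring above only consumes what a connector chooses to report: PartitionIterator reads currentMetricsValues() on every record and copies each value into the matching SQLMetric. A minimal sketch of a connector-side reader that reports one custom task metric follows; the class and the "rowsRead" metric name are illustrative, not part of this patch, and the name would have to match one of the metrics returned by the scan's supportedCustomMetrics().

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Hypothetical reader that counts how many rows it has produced and exposes
// the count through the currentMetricsValues() hook consumed by PartitionIterator.
class CountingPartitionReader(underlying: PartitionReader[InternalRow])
  extends PartitionReader[InternalRow] {

  private var rowsRead = 0L

  override def next(): Boolean = underlying.next()

  override def get(): InternalRow = {
    rowsRead += 1
    underlying.get()
  }

  override def close(): Unit = underlying.close()

  // Keep this cheap: once the wiring above is in place it is called for every record.
  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    val metric = new CustomTaskMetric {
      override def name(): String = "rowsRead" // must match a supportedCustomMetrics() entry
      override def value(): Long = rowsRead
    }
    Array(metric)
  }
}
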
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
new file mode 100644
index 000000000000..3cb20f87ae63
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.metric
+
+import java.text.NumberFormat
+import java.util.Locale
+
+import org.apache.spark.sql.connector.CustomMetric
+
+object CustomMetrics {
+  private[spark] val V2_CUSTOM = "v2Custom"
+
+  /**
+   * Given a class name, builds and returns a metric type for a V2 custom metric class
+   * `CustomMetric`.
+   */
+  def buildV2CustomMetricTypeName(customMetric: CustomMetric): String = {
+    s"${V2_CUSTOM}_${customMetric.getClass.getCanonicalName}"
+  }
+
+  /**
+   * Given a V2 custom metric type name, this method parses it and returns the corresponding
+   * `CustomMetric` class name.
+   */
+  def parseV2CustomMetricType(metricType: String): Option[String] = {
+    if (metricType.startsWith(s"${V2_CUSTOM}_")) {
+      Some(metricType.drop(V2_CUSTOM.length + 1))
+    } else {
+      None
+    }
+  }
+}
+
+/**
+ * Built-in `CustomMetric` that sums up metric values.
+ */
+class CustomSumMetric extends CustomMetric {
+  override def name(): String = "CustomSumMetric"
+
+  override def description(): String = "Sum up CustomMetric"
+
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    taskMetrics.sum.toString
+  }
+}
+
+/**
+ * Built-in `CustomMetric` that computes average of metric values.
+ */
+class CustomAvgMetric extends CustomMetric {
+  override def name(): String = "CustomAvgMetric"
+
+  override def description(): String = "Average CustomMetric"
+
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    val average = if (taskMetrics.isEmpty) {
+      0.0
+    } else {
+      taskMetrics.sum.toDouble / taskMetrics.length
+    }
+    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
+    numberFormat.format(average)
+  }
+}
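
As a quick illustration of the helpers above (not part of the patch): the metric type string simply embeds the implementing class name, which the listener can later parse back in order to instantiate the aggregator, and the built-in metrics turn per-task values into the display string.

import org.apache.spark.sql.execution.metric.{CustomAvgMetric, CustomMetrics, CustomSumMetric}

object CustomMetricsExample {
  def main(args: Array[String]): Unit = {
    val sum = new CustomSumMetric
    // The metric type encodes the class name, e.g. "v2Custom_...CustomSumMetric".
    val metricType = CustomMetrics.buildV2CustomMetricTypeName(sum)
    // parseV2CustomMetricType strips the "v2Custom_" prefix and returns the class name.
    assert(CustomMetrics.parseV2CustomMetricType(metricType)
      .contains(sum.getClass.getCanonicalName))

    // Driver-side aggregation of per-task values into the string shown in the SQL UI.
    val taskValues = Array(3L, 4L, 5L)
    println(sum.aggregateTaskMetrics(taskValues))                   // "12"
    println(new CustomAvgMetric().aggregateTaskMetrics(taskValues)) // "4"
  }
}
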
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index b1705d06344f..da39e8c455e3
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -24,6 +24,7 @@ import scala.concurrent.duration._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.sql.connector.CustomMetric
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
 import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
 
@@ -107,6 +108,15 @@ object SQLMetrics {
     acc
   }
 
+  /**
+   * Create a metric to report a data source v2 custom metric.
+   */
+  def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = {
+    val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric))
+    acc.register(sc, name = Some(customMetric.name()), countFailedValues = false)
+    acc
+  }
+
   /**
    * Create a metric to report the size information (including total, min, med, max) like data size,
    * spill size, etc.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index 5ee27c71aa73..4e32cefbe31a
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
 
@@ -52,7 +53,8 @@ class ContinuousDataSourceRDD(
     epochPollIntervalMs: Long,
     private val inputPartitions: Seq[InputPartition],
     schema: StructType,
-    partitionReaderFactory: ContinuousPartitionReaderFactory)
+    partitionReaderFactory: ContinuousPartitionReaderFactory,
+    customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
@@ -88,8 +90,12 @@ class ContinuousDataSourceRDD(
       partition.queueReader
     }
 
+    val partitionReader = readerForPartition.getPartitionReader()
    new NextIterator[InternalRow] {
      override def getNext(): InternalRow = {
+        partitionReader.currentMetricsValues.foreach { metric =>
+          customMetrics(metric.name()).set(metric.value())
+        }
        readerForPartition.next() match {
          case null =>
            finished = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index dff2fa69e42f..02893f274902
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.connector.read.PartitionReader
 import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ThreadUtils
@@ -47,6 +48,8 @@ class ContinuousQueuedDataReader(
   // Important sequencing - we must get our starting point before the provider threads start running
   private var currentOffset: PartitionOffset = reader.getOffset
 
+  def getPartitionReader(): PartitionReader[InternalRow] = reader
+
   /**
    * The record types in the read buffer.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 47faa248a552..a3238551b2fd
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -22,16 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.util.control.NonFatal
 
 import org.apache.spark.{JobExecutionStatus, SparkConf}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Status._
 import org.apache.spark.scheduler._
+import org.apache.spark.sql.connector.CustomMetric
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.OpenHashMap
 
 class SQLAppStatusListener(
@@ -199,7 +202,37 @@ class SQLAppStatusListener(
   }
 
   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
-    val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
+    val accumIds = exec.metrics.map(_.accumulatorId).toSet
+
+    val metricAggregationMap = new mutable.HashMap[String, (Array[Long], Array[Long]) => String]()
+    val metricAggregationMethods = exec.metrics.map { m =>
+      val optClassName = CustomMetrics.parseV2CustomMetricType(m.metricType)
+      val metricAggMethod = optClassName.map { className =>
+        if (metricAggregationMap.contains(className)) {
+          metricAggregationMap(className)
+        } else {
+          // Try to instantiate the custom metric object
+          try {
+            val metric = Utils.loadExtensions(classOf[CustomMetric], Seq(className), conf).head
+            val method =
+              (metrics: Array[Long], _: Array[Long]) => metric.aggregateTaskMetrics(metrics)
+            metricAggregationMap.put(className, method)
+            method
+          } catch {
+            case NonFatal(_) =>
+              // Cannot instantiate the custom metric object; we might be in the history server,
+              // which does not have the custom metric class.
+              val defaultMethod = (_: Array[Long], _: Array[Long]) => "N/A"
+              metricAggregationMap.put(className, defaultMethod)
+              defaultMethod
+          }
+        }
+      }.getOrElse(
+        // Built-in SQLMetric
+        SQLMetrics.stringValue(m.metricType, _, _)
+      )
+      (m.accumulatorId, metricAggMethod)
+    }.toMap
 
     val liveStageMetrics = exec.stages.toSeq
       .flatMap { stageId => Option(stageMetrics.get(stageId)) }
@@ -212,7 +245,7 @@ class SQLAppStatusListener(
 
     val maxMetricsFromAllStages = new mutable.HashMap[Long, Array[Long]]()
 
-    taskMetrics.filter(m => metricTypes.contains(m._1)).foreach { case (id, values) =>
+    taskMetrics.filter(m => accumIds.contains(m._1)).foreach { case (id, values) =>
       val prev = allMetrics.getOrElse(id, null)
       val updated = if (prev != null) {
         prev ++ values
@@ -223,7 +256,7 @@ class SQLAppStatusListener(
     }
 
     // Find the max for each metric id between all stages.
-    val validMaxMetrics = maxMetrics.filter(m => metricTypes.contains(m._1))
+    val validMaxMetrics = maxMetrics.filter(m => accumIds.contains(m._1))
     validMaxMetrics.foreach { case (id, value, taskId, stageId, attemptId) =>
       val updated = maxMetricsFromAllStages.getOrElse(id, Array(value, stageId, attemptId, taskId))
       if (value > updated(0)) {
@@ -236,7 +269,7 @@ class SQLAppStatusListener(
     }
 
     exec.driverAccumUpdates.foreach { case (id, value) =>
-      if (metricTypes.contains(id)) {
+      if (accumIds.contains(id)) {
        val prev = allMetrics.getOrElse(id, null)
        val updated = if (prev != null) {
          // If the driver updates same metrics as tasks and has higher value then remove
@@ -256,7 +289,7 @@ class SQLAppStatusListener(
     }
 
     val aggregatedMetrics = allMetrics.map { case (id, values) =>
-      id -> SQLMetrics.stringValue(metricTypes(id), values, maxMetricsFromAllStages.getOrElse(id,
+      id -> metricAggregationMethods(id)(values, maxMetricsFromAllStages.getOrElse(id,
         Array.empty[Long]))
     }.toMap
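
The aggregation dispatch added above boils down to: if the metric type carries the v2Custom_ prefix, instantiate the named CustomMetric class and let it aggregate the task values, falling back to "N/A" when the class cannot be loaded (for example in the history server); any other metric type keeps going through SQLMetrics.stringValue. A simplified standalone sketch of that decision, not the listener code itself, using plain reflection instead of Utils.loadExtensions:

import scala.util.Try

import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetrics}

object MetricAggregationSketch {
  // Returns a function that turns the collected task values (and the max-metrics
  // triple) into the display string shown in the SQL UI.
  def aggregationFor(metricType: String): (Array[Long], Array[Long]) => String = {
    CustomMetrics.parseV2CustomMetricType(metricType) match {
      case Some(className) =>
        // Custom metric: instantiate the class reflectively and let it aggregate.
        Try(Class.forName(className).getConstructor().newInstance().asInstanceOf[CustomMetric])
          .map(m => (values: Array[Long], _: Array[Long]) => m.aggregateTaskMetrics(values))
          .getOrElse((_: Array[Long], _: Array[Long]) => "N/A") // class missing, e.g. history server
      case None =>
        // Built-in SQLMetric types keep the existing formatting.
        (values: Array[Long], maxValues: Array[Long]) =>
          SQLMetrics.stringValue(metricType, values, maxValues)
    }
  }
}
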
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala
new file mode 100644
index 000000000000..e2fa03ff23c9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.metric
+
+import org.apache.spark.SparkFunSuite
+
+class CustomMetricsSuite extends SparkFunSuite {
+
+  test("Build/parse custom metric type") {
+    Seq(new CustomSumMetric, new CustomAvgMetric).foreach { customMetric =>
+      val metricType = CustomMetrics.buildV2CustomMetricTypeName(customMetric)
+
+      assert(metricType == CustomMetrics.V2_CUSTOM + "_" + customMetric.getClass.getCanonicalName)
+      assert(CustomMetrics.parseV2CustomMetricType(metricType).isDefined)
+      assert(CustomMetrics.parseV2CustomMetricType(metricType).get ==
+        customMetric.getClass.getCanonicalName)
+    }
+  }
+
+  test("Built-in CustomSumMetric") {
+    val metric = new CustomSumMetric
+
+    val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L)
+    assert(metric.aggregateTaskMetrics(metricValues1) == metricValues1.sum.toString)
+
+    val metricValues2 = Array.empty[Long]
+    assert(metric.aggregateTaskMetrics(metricValues2) == "0")
+  }
+
+  test("Built-in CustomAvgMetric") {
+    val metric = new CustomAvgMetric
+
+    val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L)
+    assert(metric.aggregateTaskMetrics(metricValues1) == "4.667")
+
+    val metricValues2 = Array.empty[Long]
+    assert(metric.aggregateTaskMetrics(metricValues2) == "0")
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 00f23718a0e9..a58265124d70
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -37,13 +37,17 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric, RangeInputPartition, SimpleScanBuilder}
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.status.ElementTrackingStore
 import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator}
 import org.apache.spark.util.kvstore.InMemoryStore
@@ -811,6 +815,42 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
     listener.onOtherEvent(SparkListenerSQLExecutionEnd(
       executionId, System.currentTimeMillis()))
   }
+
+
+  test("SPARK-34338: Report metrics from Datasource v2 scan") {
+    val statusStore = spark.sharedState.statusStore
+    val oldCount = statusStore.executionsList().size
+
+    val schema = new StructType().add("i", "int").add("j", "int")
+    val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder())
+    val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
+      override lazy val sparkPlan = physicalPlan
+      override lazy val executedPlan = physicalPlan
+    }
+
+    SQLExecution.withNewExecutionId(dummyQueryExecution) {
+      physicalPlan.execute().collect()
+    }
+
+    // Wait until the new execution is started and being tracked.
+    while (statusStore.executionsCount() < oldCount) {
+      Thread.sleep(100)
+    }
+
+    // Wait for the listener to finish computing the metrics for the execution.
+    while (statusStore.executionsList().isEmpty ||
+        statusStore.executionsList().last.metricValues == null) {
+      Thread.sleep(100)
+    }
+
+    val execId = statusStore.executionsList().last.executionId
+    val metrics = statusStore.executionMetrics(execId)
+    val expectedMetric = physicalPlan.metrics("custom_metric")
+    val expectedValue = "custom_metric: 12345, 12345"
+
+    assert(metrics.contains(expectedMetric.id))
+    assert(metrics(expectedMetric.id) === expectedValue)
+  }
 }
@@ -885,3 +925,50 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite {
     }
   }
 }
+
+class SimpleCustomMetric extends CustomMetric {
+  override def name(): String = "custom_metric"
+  override def description(): String = "a simple custom metric"
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    s"custom_metric: ${taskMetrics.mkString(", ")}"
+  }
+}
+
+// The following classes provide custom metrics for a V2 data source.
+object CustomMetricReaderFactory extends PartitionReaderFactory {
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    val RangeInputPartition(start, end) = partition
+    new PartitionReader[InternalRow] {
+      private var current = start - 1
+
+      override def next(): Boolean = {
+        current += 1
+        current < end
+      }
+
+      override def get(): InternalRow = InternalRow(current, -current)
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        val metric = new CustomTaskMetric {
+          override def name(): String = "custom_metric"
+          override def value(): Long = 12345
+        }
+        Array(metric)
+      }
+    }
+  }
+}
+
+class CustomMetricScanBuilder extends SimpleScanBuilder {
+  override def planInputPartitions(): Array[InputPartition] = {
+    Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
+  }
+
+  override def supportedCustomMetrics(): Array[CustomMetric] = {
+    Array(new SimpleCustomMetric)
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory = CustomMetricReaderFactory
+}
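
The test helpers above roll their own aggregation in SimpleCustomMetric; a connector could instead pair one of the built-in aggregators from CustomMetrics.scala with a task metric of the same name, so no driver-side aggregation code is needed. A sketch with illustrative names ("bytesRead" and the two classes below are not part of this patch):

import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.CustomSumMetric

// Driver side: reuse the built-in sum aggregation, reported under a connector-specific name.
class BytesReadMetric extends CustomSumMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "total bytes read by this scan"
}

// Executor side: the task metric a PartitionReader would return from currentMetricsValues().
// Its name() must match the CustomMetric above, which is what the new assert in
// PartitionIterator enforces.
case class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead"
  override def value(): Long = bytes
}

// A Scan implementation would then advertise the metric:
//   override def supportedCustomMetrics(): Array[CustomMetric] = Array(new BytesReadMetric)
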