[SPARK-34338][SQL] Report metrics from Datasource v2 scan #31451
Changes from all commits
File: DataSourceRDD.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD | |
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} | ||
| import org.apache.spark.sql.errors.QueryExecutionErrors | ||
| import org.apache.spark.sql.execution.metric.SQLMetric | ||
| import org.apache.spark.sql.vectorized.ColumnarBatch | ||
|
|
||
| class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) | ||
|
|
@@ -37,7 +38,8 @@ class DataSourceRDD( | |
| sc: SparkContext, | ||
| @transient private val inputPartitions: Seq[InputPartition], | ||
| partitionReaderFactory: PartitionReaderFactory, | ||
| columnarReads: Boolean) | ||
| columnarReads: Boolean, | ||
| customMetrics: Map[String, SQLMetric]) | ||
| extends RDD[InternalRow](sc, Nil) { | ||
|
|
||
| override protected def getPartitions: Array[Partition] = { | ||
|
|
@@ -55,11 +57,13 @@ class DataSourceRDD( | |
| val inputPartition = castPartition(split).inputPartition | ||
| val (iter, reader) = if (columnarReads) { | ||
| val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) | ||
| val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader)) | ||
| val iter = new MetricsBatchIterator( | ||
| new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) | ||
| (iter, batchReader) | ||
| } else { | ||
| val rowReader = partitionReaderFactory.createReader(inputPartition) | ||
| val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader)) | ||
| val iter = new MetricsRowIterator( | ||
| new PartitionIterator[InternalRow](rowReader, customMetrics)) | ||
| (iter, rowReader) | ||
| } | ||
| context.addTaskCompletionListener[Unit](_ => reader.close()) | ||
|
|
@@ -72,7 +76,9 @@ class DataSourceRDD( | |
| } | ||
| } | ||
|
|
||
| private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] { | ||
| private class PartitionIterator[T]( | ||
| reader: PartitionReader[T], | ||
| customMetrics: Map[String, SQLMetric]) extends Iterator[T] { | ||
| private[this] var valuePrepared = false | ||
|
|
||
| override def hasNext: Boolean = { | ||
|
|
@@ -86,6 +92,12 @@ private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[ | |
| if (!hasNext) { | ||
| throw QueryExecutionErrors.endOfStreamError() | ||
| } | ||
| reader.currentMetricsValues.foreach { metric => | ||
|
Contributor
It's still a per-row update?

Member (Author)
Yeah, will do "update metrics per some rows" in a follow-up.
||
| assert(customMetrics.contains(metric.name()), | ||
| s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " + | ||
| s"${metric.name()}") | ||
| customMetrics(metric.name()).set(metric.value()) | ||
|
Member
Do we need to check whether `customMetrics` contains the reported metric name?

Member (Author)
We can throw a user-friendly message if it doesn't.
||
| } | ||
| valuePrepared = false | ||
| reader.get() | ||
| } | ||
|
|
||
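To make the executor-side flow concrete, below is a minimal sketch (not part of this PR) of a `PartitionReader` that reports a custom task metric through `currentMetricsValues()`, which `PartitionIterator` copies into the driver-registered `SQLMetric` of the same name. The class name `CountingPartitionReader`, the wrapped delegate, and the metric name "rowsRead" are hypothetical; the import path follows this PR's snapshot, where the connector metric interfaces live directly under `org.apache.spark.sql.connector` (later Spark releases move them to `org.apache.spark.sql.connector.metric`). The reported name must match one declared by the scan's `supportedCustomMetrics()`, otherwise the assert in `PartitionIterator` fails.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Hypothetical reader that wraps another reader, counts the rows it returns,
// and publishes the count as a custom task metric named "rowsRead".
class CountingPartitionReader(delegate: PartitionReader[InternalRow])
  extends PartitionReader[InternalRow] {

  private var rowsRead = 0L

  override def next(): Boolean = {
    val hasNext = delegate.next()
    if (hasNext) rowsRead += 1
    hasNext
  }

  override def get(): InternalRow = delegate.get()

  override def close(): Unit = delegate.close()

  // PartitionIterator polls this after each returned value and sets the
  // driver-side SQLMetric registered under the same name.
  override def currentMetricsValues(): Array[CustomTaskMetric] = Array(
    new CustomTaskMetric {
      override def name(): String = "rowsRead"
      override def value(): Long = rowsRead
    }
  )
}
```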
File: DataSourceV2ScanExecBase.scala
@@ -32,8 +32,14 @@ import org.apache.spark.util.Utils | |
|
|
||
| trait DataSourceV2ScanExecBase extends LeafExecNode { | ||
|
|
||
| override lazy val metrics = Map( | ||
| "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) | ||
| lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric => | ||
| customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) | ||
| }.toMap | ||
|
|
||
| override lazy val metrics = { | ||
|
Contributor
Any plans to similarly support metrics in writes?

Member (Author)
I looked at a few V2 write nodes, but it seems we don't have any SQL metrics there (not even number of output rows). I guess we don't provide metrics for writes generally right now? If there is interest in seeing metrics in writes, I think it is okay to work on it later.

Contributor
I think it would be good to follow up and support metrics on the output side. It doesn't need to be done here, but metrics are really useful.

Member (Author)
Sounds good to me. I can work on it in follow-up PRs.
||
| Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++ | ||
| customMetrics | ||
| } | ||
|
|
||
| def scan: Scan | ||
|
|
||
|
|
||
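For reference, here is a minimal sketch (not from this PR) of the source side that feeds the `customMetrics` map built above: a `Scan` declares its metrics via `supportedCustomMetrics()`, and `DataSourceV2ScanExecBase` registers one `SQLMetric` per declared name next to "numOutputRows". The class names `RowsReadMetric` and `MyScan` are made up; `CustomMetric` is imported from `org.apache.spark.sql.connector` as in this diff (later releases use `org.apache.spark.sql.connector.metric`).

```scala
import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.types.StructType

// Driver-side definition of the metric: its name, UI description, and how the
// per-task values collected by the listener are combined for display.
class RowsReadMetric extends CustomMetric {
  override def name(): String = "rowsRead"
  override def description(): String = "number of rows read by the source"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

class MyScan(schema: StructType) extends Scan {
  override def readSchema(): StructType = schema

  // Each declared metric becomes an entry in the scan node's `customMetrics`
  // map and is merged into its `metrics` map.
  override def supportedCustomMetrics(): Array[CustomMetric] = Array(new RowsReadMetric)
}
```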
File: CustomMetrics.scala (new file)
| @@ -0,0 +1,79 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.execution.metric | ||
|
|
||
| import java.text.NumberFormat | ||
| import java.util.Locale | ||
|
|
||
| import org.apache.spark.sql.connector.CustomMetric | ||
|
|
||
| object CustomMetrics { | ||
| private[spark] val V2_CUSTOM = "v2Custom" | ||
|
|
||
| /** | ||
| * Given a class name, builds and returns a metric type for a V2 custom metric class | ||
| * `CustomMetric`. | ||
| */ | ||
| def buildV2CustomMetricTypeName(customMetric: CustomMetric): String = { | ||
| s"${V2_CUSTOM}_${customMetric.getClass.getCanonicalName}" | ||
| } | ||
|
|
||
| /** | ||
| * Given a V2 custom metric type name, this method parses it and returns the corresponding | ||
| * `CustomMetric` class name. | ||
| */ | ||
| def parseV2CustomMetricType(metricType: String): Option[String] = { | ||
| if (metricType.startsWith(s"${V2_CUSTOM}_")) { | ||
| Some(metricType.drop(V2_CUSTOM.length + 1)) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Built-in `CustomMetric` that sums up metric values. | ||
| */ | ||
| class CustomSumMetric extends CustomMetric { | ||
| override def name(): String = "CustomSumMetric" | ||
|
|
||
| override def description(): String = "Sum up CustomMetric" | ||
|
|
||
| override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { | ||
| taskMetrics.sum.toString | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Built-in `CustomMetric` that computes average of metric values. | ||
| */ | ||
| class CustomAvgMetric extends CustomMetric { | ||
| override def name(): String = "CustomAvgMetric" | ||
|
|
||
| override def description(): String = "Average CustomMetric" | ||
|
|
||
| override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { | ||
| val average = if (taskMetrics.isEmpty) { | ||
| 0.0 | ||
| } else { | ||
| taskMetrics.sum.toDouble / taskMetrics.length | ||
| } | ||
| val numberFormat = NumberFormat.getNumberInstance(Locale.US) | ||
| numberFormat.format(average) | ||
| } | ||
| } |
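A quick illustration of how the helpers in this new file behave, based only on the code above; the object name `CustomMetricsExample` is invented for the example.

```scala
import org.apache.spark.sql.execution.metric.{CustomAvgMetric, CustomMetrics, CustomSumMetric}

object CustomMetricsExample {
  def main(args: Array[String]): Unit = {
    val sum = new CustomSumMetric

    // The metric type string embeds the implementing class name after the
    // "v2Custom_" prefix, e.g.
    // "v2Custom_org.apache.spark.sql.execution.metric.CustomSumMetric".
    val metricType = CustomMetrics.buildV2CustomMetricTypeName(sum)

    // parseV2CustomMetricType recovers the class name, which the SQL listener
    // later instantiates reflectively to aggregate task values.
    assert(CustomMetrics.parseV2CustomMetricType(metricType)
      .contains(classOf[CustomSumMetric].getCanonicalName))

    // aggregateTaskMetrics folds per-task values into a single display string.
    println(sum.aggregateTaskMetrics(Array(1L, 2L, 3L)))               // "6"
    println(new CustomAvgMetric().aggregateTaskMetrics(Array(1L, 2L))) // "1.5"
  }
}
```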
File: SQLAppStatusListener.scala
|
|
@@ -22,16 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger | |
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.collection.mutable | ||
| import scala.util.control.NonFatal | ||
|
|
||
| import org.apache.spark.{JobExecutionStatus, SparkConf} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.internal.config.Status._ | ||
| import org.apache.spark.scheduler._ | ||
| import org.apache.spark.sql.connector.CustomMetric | ||
| import org.apache.spark.sql.errors.QueryExecutionErrors | ||
| import org.apache.spark.sql.execution.SQLExecution | ||
| import org.apache.spark.sql.execution.metric._ | ||
| import org.apache.spark.sql.internal.StaticSQLConf._ | ||
| import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} | ||
| import org.apache.spark.util.Utils | ||
| import org.apache.spark.util.collection.OpenHashMap | ||
|
|
||
| class SQLAppStatusListener( | ||
|
|
@@ -199,7 +202,37 @@ class SQLAppStatusListener( | |
| } | ||
|
|
||
| private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { | ||
| val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap | ||
| val accumIds = exec.metrics.map(_.accumulatorId).toSet | ||
|
|
||
| val metricAggregationMap = new mutable.HashMap[String, (Array[Long], Array[Long]) => String]() | ||
| val metricAggregationMethods = exec.metrics.map { m => | ||
| val optClassName = CustomMetrics.parseV2CustomMetricType(m.metricType) | ||
| val metricAggMethod = optClassName.map { className => | ||
| if (metricAggregationMap.contains(className)) { | ||
| metricAggregationMap(className) | ||
| } else { | ||
| // Try to initiate custom metric object | ||
| try { | ||
| val metric = Utils.loadExtensions(classOf[CustomMetric], Seq(className), conf).head | ||
|
Member
Hi @viirya, the …

Member (Author)
For current usage, I don't see there are other necessary args for …

Member
I agree that we should keep it as simple as possible. I'm asking because a simple case class does not work, since it doesn't have a no-arg constructor or a 1-arg constructor. It took me a while to dig out, because it just doesn't work and there are no WARN/ERROR logs.

Member (Author)
Good point. Added a warning log at #37386
||
| val method = | ||
| (metrics: Array[Long], _: Array[Long]) => metric.aggregateTaskMetrics(metrics) | ||
| metricAggregationMap.put(className, method) | ||
| method | ||
| } catch { | ||
| case NonFatal(_) => | ||
| // Cannot initialize custom metric object, we might be in history server that does | ||
| // not have the custom metric class. | ||
| val defaultMethod = (_: Array[Long], _: Array[Long]) => "N/A" | ||
| metricAggregationMap.put(className, defaultMethod) | ||
| defaultMethod | ||
| } | ||
| } | ||
| }.getOrElse( | ||
| // Built-in SQLMetric | ||
| SQLMetrics.stringValue(m.metricType, _, _) | ||
| ) | ||
| (m.accumulatorId, metricAggMethod) | ||
| }.toMap | ||
|
|
||
| val liveStageMetrics = exec.stages.toSeq | ||
| .flatMap { stageId => Option(stageMetrics.get(stageId)) } | ||
|
|
@@ -212,7 +245,7 @@ class SQLAppStatusListener( | |
|
|
||
| val maxMetricsFromAllStages = new mutable.HashMap[Long, Array[Long]]() | ||
|
|
||
| taskMetrics.filter(m => metricTypes.contains(m._1)).foreach { case (id, values) => | ||
| taskMetrics.filter(m => accumIds.contains(m._1)).foreach { case (id, values) => | ||
| val prev = allMetrics.getOrElse(id, null) | ||
| val updated = if (prev != null) { | ||
| prev ++ values | ||
|
|
@@ -223,7 +256,7 @@ class SQLAppStatusListener( | |
| } | ||
|
|
||
| // Find the max for each metric id between all stages. | ||
| val validMaxMetrics = maxMetrics.filter(m => metricTypes.contains(m._1)) | ||
| val validMaxMetrics = maxMetrics.filter(m => accumIds.contains(m._1)) | ||
| validMaxMetrics.foreach { case (id, value, taskId, stageId, attemptId) => | ||
| val updated = maxMetricsFromAllStages.getOrElse(id, Array(value, stageId, attemptId, taskId)) | ||
| if (value > updated(0)) { | ||
|
|
@@ -236,7 +269,7 @@ class SQLAppStatusListener( | |
| } | ||
|
|
||
| exec.driverAccumUpdates.foreach { case (id, value) => | ||
| if (metricTypes.contains(id)) { | ||
| if (accumIds.contains(id)) { | ||
| val prev = allMetrics.getOrElse(id, null) | ||
| val updated = if (prev != null) { | ||
| // If the driver updates same metrics as tasks and has higher value then remove | ||
|
|
@@ -256,7 +289,7 @@ class SQLAppStatusListener( | |
| } | ||
|
|
||
| val aggregatedMetrics = allMetrics.map { case (id, values) => | ||
| id -> SQLMetrics.stringValue(metricTypes(id), values, maxMetricsFromAllStages.getOrElse(id, | ||
| id -> metricAggregationMethods(id)(values, maxMetricsFromAllStages.getOrElse(id, | ||
| Array.empty[Long])) | ||
||
| }.toMap | ||
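As the review thread above notes, `Utils.loadExtensions` can only re-create a `CustomMetric` implementation that has a no-arg constructor (or one taking a `SparkConf`); otherwise the `NonFatal` handler silently falls back to rendering "N/A". A small sketch of the two situations, with hypothetical class names and the `CustomMetric` import path used in this diff:

```scala
import org.apache.spark.sql.connector.CustomMetric

// Works: a public no-arg constructor, so the listener (including the history
// server, if the class is on its classpath) can instantiate it reflectively
// and call aggregateTaskMetrics on the collected task values.
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "total bytes read"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Pitfall from the thread above: the only constructor takes an argument, so
// reflective instantiation fails and the metric is rendered as "N/A" without
// any WARN/ERROR log (until the follow-up in #37386 added one).
case class LabeledMetric(label: String) extends CustomMetric {
  override def name(): String = label
  override def description(): String = s"metric $label"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}
```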
|
|
||
|
|
||