From eebf9c67fb8d062dda73853b09d2ff4c0e593d77 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 27 Jan 2021 09:50:03 -0800 Subject: [PATCH 1/5] Add custom metrics. --- .../kafka010/KafkaBatchPartitionReader.scala | 9 ++++ .../sql/kafka010/KafkaMicroBatchStream.scala | 7 ++- .../kafka010/consumer/KafkaDataConsumer.scala | 14 ++++++ .../sql/connector/read/PartitionReader.java | 9 ++++ .../read/streaming/CustomMetric.java | 43 +++++++++++++++++++ .../read/streaming/MicroBatchStream.java | 9 ++++ .../read/streaming/CustomMetrics.scala | 24 +++++++++++ .../datasources/v2/DataSourceRDD.scala | 18 ++++++-- .../datasources/v2/MicroBatchScanExec.scala | 23 +++++++++- 9 files changed, 150 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/CustomMetric.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/read/streaming/CustomMetrics.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index 8b37fd6e7e2b..518367ccfd69 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming.{CustomMetric, CustomSumMetric} import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. 
*/ @@ -105,4 +106,12 @@ private case class KafkaBatchPartitionReader( range } } + + override def getCustomMetrics(): Array[CustomMetric] = { + Array( + CustomSumMetric("offsetOutOfRange", "estimated number of fetched offsets out of range", + consumer.getNumOffsetOutOfRange()), + CustomSumMetric("dataLoss", "number of data loss error", + consumer.getNumDataLoss())) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 1c816ab82d3e..e14f2c7a7219 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} -import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset, ReadAllAvailable, ReadLimit, ReadMaxRows, SupportsAdmissionControl} +import org.apache.spark.sql.connector.read.streaming.{CustomMetric, CustomSumMetric, MicroBatchStream, Offset, ReadAllAvailable, ReadLimit, ReadMaxRows, SupportsAdmissionControl} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.UninterruptibleThread @@ -217,4 +217,9 @@ private[kafka010] class KafkaMicroBatchStream( logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") } } + + override def supportedCustomMetrics(): Array[CustomMetric] = + Array( + CustomSumMetric("offsetOutOfRange", "estimated number of fetched offsets out of range", 0L), + CustomSumMetric("dataLoss", "number of data loss error", 0L)) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala index 5c92d110a630..37fe38ea94ec 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala @@ -239,6 +239,9 @@ private[kafka010] class KafkaDataConsumer( fetchedDataPool: FetchedDataPool) extends Logging { import KafkaDataConsumer._ + private var offsetOutOfRange = 0L + private var dataLoss = 0L + private val isTokenProviderEnabled = HadoopDelegationTokenManager.isServiceEnabled(SparkEnv.get.conf, "kafka") @@ -329,7 +332,14 @@ private[kafka010] class KafkaDataConsumer( reportDataLoss(topicPartition, groupId, failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e) + + val oldToFetchOffsetd = toFetchOffset toFetchOffset = getEarliestAvailableOffsetBetween(consumer, toFetchOffset, untilOffset) + if (toFetchOffset == UNKNOWN_OFFSET) { + offsetOutOfRange += (untilOffset - oldToFetchOffsetd) + } else { + offsetOutOfRange += (toFetchOffset - oldToFetchOffsetd) + } } } @@ -350,6 +360,9 @@ private[kafka010] class KafkaDataConsumer( consumer.getAvailableOffsetRange() } + def getNumOffsetOutOfRange(): Long = offsetOutOfRange + def getNumDataLoss(): Long = dataLoss + /** * Release borrowed objects in data reader to the pool. 
Once the instance is created, caller * must call method after using the instance to make sure resources are not leaked. @@ -596,6 +609,7 @@ private[kafka010] class KafkaDataConsumer( message: String, cause: Throwable = null): Unit = { val finalMessage = s"$message ${additionalMessage(topicPartition, groupId, failOnDataLoss)}" + dataLoss += 1 reportDataLoss0(failOnDataLoss, finalMessage, cause) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java index 23fbd95800e2..cc3fc5028a39 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java @@ -21,6 +21,7 @@ import java.io.IOException; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.read.streaming.CustomMetric; /** * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or @@ -48,4 +49,12 @@ public interface PartitionReader extends Closeable { * Return the current record. This method should return same value until `next` is called. */ T get(); + + /** + * Returns an array of custom metrics. By default it returns empty array. + */ + default CustomMetric[] getCustomMetrics() { + CustomMetric[] NO_METRICS = {}; + return NO_METRICS; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/CustomMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/CustomMetric.java new file mode 100644 index 000000000000..5ebd51c8e001 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/CustomMetric.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read.streaming; + +import org.apache.spark.annotation.Evolving; + +/** + * A custom metric for {@link SparkDataStream}. + * + * @since 3.2.0 + */ +@Evolving +public interface CustomMetric { + /** + * Returns the name of custom metric. + */ + String getName(); + + /** + * Returns the description of custom metric. + */ + String getDescription(); + + /** + * Returns the value of custom metric. 
+ */ + Long getValue(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java index 40ecbf0578ee..7549aaaa8769 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java @@ -56,4 +56,13 @@ public interface MicroBatchStream extends SparkDataStream { * Returns a factory to create a {@link PartitionReader} for each {@link InputPartition}. */ PartitionReaderFactory createReaderFactory(); + + /** + * Returns an array of supported custom metrics with name and description. + * By default it returns empty array. + */ + default CustomMetric[] supportedCustomMetrics() { + CustomMetric[] NO_METRICS = {}; + return NO_METRICS; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/read/streaming/CustomMetrics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/read/streaming/CustomMetrics.scala new file mode 100644 index 000000000000..e299978061d9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/read/streaming/CustomMetrics.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.read.streaming + +case class CustomSumMetric(name: String, desc: String, value: Long) extends CustomMetric { + override def getName(): String = name + override def getDescription: String = desc + override def getValue: java.lang.Long = value +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 63403b957723..562b170bfae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.CompletionIterator class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable @@ -36,7 +37,8 @@ class DataSourceRDD( sc: SparkContext, @transient private val inputPartitions: Seq[InputPartition], partitionReaderFactory: PartitionReaderFactory, - columnarReads: Boolean) + columnarReads: Boolean, + onCompletion: PartitionReader[_] => Unit = _ => {}) extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { @@ -55,11 +57,21 @@ class DataSourceRDD( val (iter, reader) = if (columnarReads) { val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader)) - (iter, batchReader) + def completionFunction = { + onCompletion(batchReader) + } + val completionIterator = CompletionIterator[ColumnarBatch, Iterator[ColumnarBatch]]( + iter, completionFunction) + (completionIterator, batchReader) } else { val rowReader = partitionReaderFactory.createReader(inputPartition) val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader)) - (iter, rowReader) + def completionFunction = { + onCompletion(rowReader) + } + val completionIterator = CompletionIterator[InternalRow, Iterator[InternalRow]]( + iter, completionFunction) + (completionIterator, rowReader) } context.addTaskCompletionListener[Unit](_ => reader.close()) // TODO: SPARK-25083 remove the type erasure hack in data source scan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala index bca28e3cacb6..f9c562f9e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning a micro-batch of 
data from a data source. @@ -33,6 +34,14 @@ case class MicroBatchScanExec( @transient start: Offset, @transient end: Offset) extends DataSourceV2ScanExecBase { + override lazy val metrics = { + val customMetrics = stream.supportedCustomMetrics().map { customMetric => + customMetric.getName -> SQLMetrics.createMetric(sparkContext, customMetric.getDescription) + } + Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++ + customMetrics + } + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: MicroBatchScanExec => this.stream == other.stream @@ -45,7 +54,17 @@ case class MicroBatchScanExec( override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory() + /** + * The callback function which is called when the output iterator of input RDD is consumed + * completely. + */ + private def onOutputCompletion(reader: PartitionReader[_]) = { + reader.getCustomMetrics.foreach { metric => + longMetric(metric.getName) += metric.getValue + } + } + override lazy val inputRDD: RDD[InternalRow] = { - new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar) + new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, onOutputCompletion) } } From ee6fb13c3c6bc8d70a8d7cb2dbf00a27dc1083d1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Apr 2021 01:49:53 -0700 Subject: [PATCH 2/5] Remove unused import. --- .../spark/sql/execution/datasources/v2/DataSourceRDD.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5db835191dcf..7850dfa39d16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.CompletionIterator class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable From 29ccb06a20888ef3f01934cc9325ca2f6b0fbdae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Apr 2021 02:00:28 -0700 Subject: [PATCH 3/5] Fix merging issue. 
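This rebase drops the interim getCustomMetrics()/CustomSumMetric plumbing from PATCH 1/5 and moves the Kafka reader onto the per-task reporting hook currentMetricsValues(): Array[CustomTaskMetric]. For reference, below is a minimal sketch of a connector-side reader using that hook; the class name, the rowsRead counter and the iterator plumbing are illustrative only, and it assumes CustomTaskMetric pairs name() with a long value(), matching how the Kafka reader in this patch builds its metric values.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.CustomTaskMetric
    import org.apache.spark.sql.connector.read.PartitionReader

    // Illustrative reader: only currentMetricsValues() mirrors the API used by
    // KafkaBatchPartitionReader in this patch; everything else is scaffolding.
    class ExamplePartitionReader(rows: Iterator[InternalRow])
      extends PartitionReader[InternalRow] {

      private var current: InternalRow = _
      private var rowsRead = 0L

      override def next(): Boolean = {
        if (rows.hasNext) {
          current = rows.next()
          rowsRead += 1
          true
        } else {
          false
        }
      }

      override def get(): InternalRow = current

      override def close(): Unit = {}

      // Report the task-local value; the driver adds it to the SQLMetric that
      // was registered under the same metric name.
      override def currentMetricsValues(): Array[CustomTaskMetric] = {
        val rowsReadMetric = new CustomTaskMetric {
          override def name(): String = "rowsRead"
          override def value(): Long = rowsRead
        }
        Array(rowsReadMetric)
      }
    }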
--- .../spark/sql/kafka010/KafkaBatchPartitionReader.scala | 10 +--------- .../spark/sql/kafka010/KafkaSourceProvider.scala | 5 ++++- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index a03949e01b6d..0aa51a940160 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.connector.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} -import org.apache.spark.sql.connector.read.streaming.{CustomMetric, CustomSumMetric} import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */ @@ -107,14 +107,6 @@ private case class KafkaBatchPartitionReader( } } - override def getCustomMetrics(): Array[CustomMetric] = { - Array( - CustomSumMetric("offsetOutOfRange", "estimated number of fetched offsets out of range", - consumer.getNumOffsetOutOfRange()), - CustomSumMetric("dataLoss", "number of data loss error", - consumer.getNumDataLoss())) - } - override def currentMetricsValues(): Array[CustomTaskMetric] = { val offsetOutOfRange = new CustomTaskMetric { override def name(): String = "offsetOutOfRange" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 996d8deea46e..28ddb975d870 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} import java.util.{Locale, UUID} + import scala.collection.JavaConverters._ + import org.apache.kafka.clients.consumer.ConsumerConfig import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} + import org.apache.spark.internal.Logging import org.apache.spark.kafka010.KafkaConfigUpdater -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.CustomMetric import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} From 6f3d27d8acf327ad702868eec2195212bd0895d6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Apr 2021 17:22:57 -0700 Subject: [PATCH 4/5] Add metrics. 
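Besides wiring up OffsetOutOfRangeMetric and DataLossMetric below, this pattern generalizes: a connector declares a driver-side metric by subclassing one of the built-in aggregating classes and fixing only the name and description. A hedged sketch follows; RecordsSkippedMetric is a made-up example, and it assumes the built-ins stay importable from org.apache.spark.sql.execution.metric as in this series.

    import org.apache.spark.sql.execution.metric.CustomSumMetric

    // Hypothetical driver-side metric following the OffsetOutOfRangeMetric /
    // DataLossMetric pattern in this patch: the subclass fixes only name and
    // description, while CustomSumMetric supplies aggregateTaskMetrics.
    class RecordsSkippedMetric extends CustomSumMetric {
      override def name(): String = "recordsSkipped"
      override def description(): String = "number of records skipped by the reader"
    }

The scan advertises such metrics through supportedCustomMetrics(), and the name must match the name() of the CustomTaskMetric values reported by the partition readers so that the task values are folded into the right SQLMetric.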
--- .../sql/kafka010/KafkaSourceProvider.scala | 14 +++++++++++--- .../sql/execution/metric/CustomMetrics.scala | 18 ++++++++++-------- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../execution/metric/CustomMetricsSuite.scala | 7 +++---- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28ddb975d870..c34c43563014 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -507,13 +507,21 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def supportedCustomMetrics(): Array[CustomMetric] = { - Array( - new CustomSumMetric("offsetOutOfRange", "estimated number of fetched offsets out of range"), - new CustomSumMetric("dataLoss", "number of data loss error")) + Array(new OffsetOutOfRangeMetric, new DataLossMetric) } } } +private[spark] class OffsetOutOfRangeMetric extends CustomSumMetric { + override def name(): String = "offsetOutOfRange" + override def description(): String = "estimated number of fetched offsets out of range" +} + +private[spark] class DataLossMetric extends CustomSumMetric { + override def name(): String = "dataLoss" + override def description(): String = "number of data loss error" +} + private[kafka010] object KafkaSourceProvider extends Logging { private val ASSIGN = "assign" private val SUBSCRIBE_PATTERN = "subscribepattern" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala index 8da744ecbd54..cbad8856ff0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala @@ -47,12 +47,13 @@ object CustomMetrics { } /** - * Built-in `CustomMetric` that sums up metric values. + * Built-in `CustomMetric` that sums up metric values. Note that please extend this class + * and override `name` and `description` to create your custom metric for real usage. */ -class CustomSumMetric(metricName: String, metricDescption: String) extends CustomMetric { - override def name(): String = metricName +class CustomSumMetric extends CustomMetric { + override def name(): String = "CustomSumMetric" - override def description(): String = metricDescption + override def description(): String = "Sum up CustomMetric" override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { taskMetrics.sum.toString @@ -60,12 +61,13 @@ class CustomSumMetric(metricName: String, metricDescption: String) extends Custo } /** - * Built-in `CustomMetric` that computes average of metric values. + * Built-in `CustomMetric` that computes average of metric values. Note that please extend this + * class and override `name` and `description` to create your custom metric for real usage. 
*/ -class CustomAvgMetric(metricName: String, metricDescption: String) extends CustomMetric { - override def name(): String = metricName +class CustomAvgMetric extends CustomMetric { + override def name(): String = "CustomAvgMetric" - override def description(): String = metricDescption + override def description(): String = "Average CustomMetric" override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { val average = if (taskMetrics.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index da39e8c455e3..0e48e6efeee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -113,7 +113,7 @@ object SQLMetrics { */ def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = { val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric)) - acc.register(sc, name = Some(customMetric.name()), countFailedValues = false) + acc.register(sc, name = Some(customMetric.description()), countFailedValues = false) acc } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala index 1fa50a0843fd..e2fa03ff23c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala @@ -22,8 +22,7 @@ import org.apache.spark.SparkFunSuite class CustomMetricsSuite extends SparkFunSuite { test("Build/parse custom metric metric type") { - Seq(new CustomSumMetric("custom sum", "Sum up CustomMetric"), - new CustomAvgMetric("custom average", "Average CustomMetric")).foreach { customMetric => + Seq(new CustomSumMetric, new CustomAvgMetric).foreach { customMetric => val metricType = CustomMetrics.buildV2CustomMetricTypeName(customMetric) assert(metricType == CustomMetrics.V2_CUSTOM + "_" + customMetric.getClass.getCanonicalName) @@ -34,7 +33,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomSumMetric") { - val metric = new CustomSumMetric("custom sum", "Sum up CustomMetric") + val metric = new CustomSumMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == metricValues1.sum.toString) @@ -44,7 +43,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomAvgMetric") { - val metric = new CustomAvgMetric("custom average", "Average CustomMetric") + val metric = new CustomAvgMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == "4.667") From 54bf5ae66d830218f1ba9ddfacdda6db2f6e54b2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 23 Apr 2021 09:13:41 -0700 Subject: [PATCH 5/5] Make built-in custom metric classes as abstract classes. 
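With CustomSumMetric and CustomAvgMetric now abstract, they can no longer be instantiated directly, hence the TestCustomSumMetric / TestCustomAvgMetric helpers added to CustomMetricsSuite; users subclass them and inherit aggregateTaskMetrics. A small illustrative check with a made-up FetchLatencyMetric; the result is parsed back to a number so the check does not depend on the exact decimal formatting used by CustomAvgMetric.

    import org.apache.spark.sql.execution.metric.CustomAvgMetric

    // Hypothetical subclass: name and description are all that is left to provide.
    class FetchLatencyMetric extends CustomAvgMetric {
      override def name(): String = "fetchLatencyMs"
      override def description(): String = "average fetch latency in milliseconds"
    }

    object FetchLatencyMetricCheck extends App {
      val metric = new FetchLatencyMetric
      // The inherited aggregation averages the per-task values: (2 + 4 + 6) / 3 = 4.
      assert(metric.aggregateTaskMetrics(Array(2L, 4L, 6L)).toDouble == 4.0)
    }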
--- .../sql/execution/metric/CustomMetrics.scala | 10 ++-------- .../execution/metric/CustomMetricsSuite.scala | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala index cbad8856ff0c..cc28be3ca8ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala @@ -50,10 +50,7 @@ object CustomMetrics { * Built-in `CustomMetric` that sums up metric values. Note that please extend this class * and override `name` and `description` to create your custom metric for real usage. */ -class CustomSumMetric extends CustomMetric { - override def name(): String = "CustomSumMetric" - - override def description(): String = "Sum up CustomMetric" +abstract class CustomSumMetric extends CustomMetric { override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { taskMetrics.sum.toString @@ -64,10 +61,7 @@ class CustomSumMetric extends CustomMetric { * Built-in `CustomMetric` that computes average of metric values. Note that please extend this * class and override `name` and `description` to create your custom metric for real usage. */ -class CustomAvgMetric extends CustomMetric { - override def name(): String = "CustomAvgMetric" - - override def description(): String = "Average CustomMetric" +abstract class CustomAvgMetric extends CustomMetric { override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { val average = if (taskMetrics.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala index e2fa03ff23c9..020f3f494a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite class CustomMetricsSuite extends SparkFunSuite { test("Build/parse custom metric metric type") { - Seq(new CustomSumMetric, new CustomAvgMetric).foreach { customMetric => + Seq(new TestCustomSumMetric, new TestCustomAvgMetric).foreach { customMetric => val metricType = CustomMetrics.buildV2CustomMetricTypeName(customMetric) assert(metricType == CustomMetrics.V2_CUSTOM + "_" + customMetric.getClass.getCanonicalName) @@ -33,7 +33,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomSumMetric") { - val metric = new CustomSumMetric + val metric = new TestCustomSumMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == metricValues1.sum.toString) @@ -43,7 +43,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomAvgMetric") { - val metric = new CustomAvgMetric + val metric = new TestCustomAvgMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == "4.667") @@ -52,3 +52,13 @@ class CustomMetricsSuite extends SparkFunSuite { assert(metric.aggregateTaskMetrics(metricValues2) == "0") } } + +private[spark] class TestCustomSumMetric extends CustomSumMetric { + override def name(): String = "CustomSumMetric" + override def description(): String = "Sum up CustomMetric" +} + +private[spark] class TestCustomAvgMetric extends 
CustomAvgMetric { + override def name(): String = "CustomAvgMetric" + override def description(): String = "Average CustomMetric" +}