diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 9cf480caba3e4..1a44cf20075dc 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1300,10 +1300,28 @@ Configuration of in-memory caching can be done using the `setConf` method on `Sp
+## QueryExecutionListener Options
+
+Use this configuration option to attach query execution listeners to a newly created `SparkSession`.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.queryExecutionListeners</code></td>
+  <td>(none)</td>
+  <td>
+    A comma-separated list of classes that implement <code>QueryExecutionListener</code>. When creating a
+    SparkSession, instances of these listeners will be added to it. These classes need to have a
+    zero-argument constructor. If a specified class cannot be found, or does not have a zero-argument
+    constructor, SparkSession creation will fail with an exception.
+  </td>
+</tr>
+</table>
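To make the new documentation section above concrete, here is a minimal sketch of a listener that satisfies the zero-argument-constructor requirement and of how it would be attached through the new option. The package, class name, and output path (`com.example.LoggingQueryExecutionListener`, `/tmp/listener-example`) are hypothetical, and the callback signatures assume the four-argument API introduced later in this patch.

```scala
package com.example  // hypothetical package, not part of this patch

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener}

// A listener with a zero-argument constructor, as required by
// spark.sql.queryExecutionListeners.
class LoggingQueryExecutionListener extends QueryExecutionListener {

  override def onSuccess(
      funcName: String,
      qe: QueryExecution,
      durationNs: Long,
      outputParams: Option[OutputParams]): Unit = {
    println(s"$funcName succeeded in ${durationNs / 1e6} ms, output: $outputParams")
  }

  override def onFailure(
      funcName: String,
      qe: QueryExecution,
      exception: Exception,
      outputParams: Option[OutputParams]): Unit = {
    println(s"$funcName failed: ${exception.getMessage}")
  }
}

object LoggingQueryExecutionListenerExample {
  def main(args: Array[String]): Unit = {
    // The listener class is instantiated reflectively while the SparkSession is created.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.queryExecutionListeners", "com.example.LoggingQueryExecutionListener")
      .getOrCreate()

    spark.range(10).write.mode("overwrite").parquet("/tmp/listener-example")
    spark.stop()
  }
}
```

With this wiring, session creation registers the listener, and the final write should invoke `onSuccess` with `funcName` equal to `"save"` and an `OutputParams` describing the parquet destination, per the `DataFrameWriter` changes below.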
+
 ## Other Configuration Options
 The following options can also be used to tune the performance of query execution. It is possible
-that these options will be deprecated in future release as more optimizations are performed automatically.
+that these options will be deprecated in a future release as more optimizations are performed
+automatically.
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143523387..a5609f2d2489c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -133,7 +133,13 @@ object MimaExcludes {
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.startOffset"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.endOffset"),
       ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"),
-      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query")
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query"),
+
+      // [SPARK-18120][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onSuccess"),
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onSuccess"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure")
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ff1f0177e8ba0..9e8afcff6c097 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,10 +26,13 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.OutputParams
 
 /**
  * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems,
@@ -189,6 +192,32 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     this
   }
 
+  /**
+   * Wraps a DataFrameWriter action to track the query execution and time cost, then reports them
+   * to the user-registered callback functions.
+   *
+   * @param funcName an identifier for the method executing the query
+   * @param qe the `QueryExecution` object associated with the query
+   * @param outputParams the output parameters, useful for query analysis
+   * @param action the function that executes the query, after which the listener methods get
+   *               called.
+ */ + private def withAction( + funcName: String, + qe: QueryExecution, + outputParams: OutputParams)(action: => Unit) = { + try { + val start = System.nanoTime() + action + val end = System.nanoTime() + df.sparkSession.listenerManager.onSuccess(funcName, qe, end - start, Some(outputParams)) + } catch { + case e: Exception => + df.sparkSession.listenerManager.onFailure(funcName, qe, e, Some(outputParams)) + throw e + } + } + /** * Saves the content of the `DataFrame` at the specified path. * @@ -218,7 +247,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { bucketSpec = getBucketSpec, options = extraOptions.toMap) - dataSource.write(mode, df) + val destination = source match { + case "jdbc" => extraOptions.get(JDBCOptions.JDBC_TABLE_NAME) + case _ => extraOptions.get("path") + } + val outputParams = OutputParams(source, destination, extraOptions.toMap) + withAction("save", df.queryExecution, outputParams) { + dataSource.write(mode, df) + } } /** @@ -261,13 +297,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { ) } - df.sparkSession.sessionState.executePlan( + val qe = df.sparkSession.sessionState.executePlan( InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], child = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, - ifNotExists = false)).toRdd + ifNotExists = false)) + val outputParams = OutputParams(source, Some(tableIdent.unquotedString), extraOptions.toMap) + withAction("insertInto", qe, outputParams)(qe.toRdd) } private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => @@ -324,7 +362,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } } @@ -428,8 +466,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec ) - df.sparkSession.sessionState.executePlan( - CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd + val qe = df.sparkSession.sessionState.executePlan( + CreateTable(tableDesc, mode, Some(df.logicalPlan))) + val outputParams = OutputParams(source, Some(tableIdent.unquotedString), extraOptions.toMap) + withAction("saveAsTable", qe, outputParams)(qe.toRdd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index f3dde480eabe0..8a027c56cdf5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -40,12 +40,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, LongType, StructType} -import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.util.{ExecutionListenerManager, 
QueryExecutionListener}
 import org.apache.spark.util.Utils
 
@@ -876,6 +876,9 @@
       }
       session = new SparkSession(sparkContext)
       options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
+      for (qeListener <- createQueryExecutionListeners(session.sparkContext.getConf)) {
+        session.listenerManager.register(qeListener)
+      }
       defaultSession.set(session)
 
       // Register a successfully instantiated context to the singleton. This should be at the
@@ -893,6 +896,12 @@
       }
     }
 
+  private def createQueryExecutionListeners(conf: SparkConf): Seq[QueryExecutionListener] = {
+    conf.get(StaticSQLConf.QUERY_EXECUTION_LISTENERS)
+      .map(Utils.classForName(_))
+      .map(_.newInstance().asInstanceOf[QueryExecutionListener])
+  }
+
   /**
    * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5ba4192512a59..b8b9a2e03f638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1047,4 +1047,14 @@
       "SQL configuration and the current database.")
     .booleanConf
     .createWithDefault(false)
+
+  val QUERY_EXECUTION_LISTENERS = buildConf("spark.sql.queryExecutionListeners")
+    .doc("A comma-separated list of classes that implement QueryExecutionListener. When creating " +
+      "a SparkSession, instances of these listeners will be added to it. These classes need " +
+      "to have a zero-argument constructor. If a specified class cannot be found, or does not " +
+      "have a zero-argument constructor, SparkSession creation will fail with an exception.")
+    .stringConf
+    .toSequence
+    .createWithDefault(Nil)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 26ad0eadd9d4c..2f0b39d8a7990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -44,12 +44,15 @@
    * @param qe the QueryExecution object that carries detail information like logical plan,
    *           physical plan, etc.
    * @param durationNs the execution time for this query in nanoseconds.
-   *
-   * @note This can be invoked by multiple different threads.
+   * @param outputParams the output parameters when the method is invoked as a result of a
+   *                     write operation; for a read operation this will be `None`
+   *
+   * @note This can be invoked by multiple different threads.
    */
  @DeveloperApi
- def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit
-
+ def onSuccess(
+     funcName: String,
+     qe: QueryExecution,
+     durationNs: Long,
+     outputParams: Option[OutputParams]): Unit
+
  /**
   * A callback function that will be called when a query execution failed.
   *
   * @param qe the QueryExecution object that carries detail information like logical plan,
   *           physical plan, etc.
   * @param exception the exception that failed this query.
+  * @param outputParams the output parameters when the method is invoked as a result of a
+  *                     write operation; for a read operation this will be `None`
   *
   * @note This can be invoked by multiple different threads.
   */
  @DeveloperApi
- def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit
+ def onFailure(
+     funcName: String,
+     qe: QueryExecution,
+     exception: Exception,
+     outputParams: Option[OutputParams]): Unit
 }
 
+/**
+ * Contains extra information, useful for query analysis, that is passed on from the methods in
+ * `org.apache.spark.sql.DataFrameWriter` while writing to a data source.
+ *
+ * @param datasourceType the type of data source written to, e.g. csv, parquet, json, hive, jdbc
+ * @param destination the path or table name written to
+ * @param options the map of output options for the underlying data source, specified through
+ *                the `org.apache.spark.sql.DataFrameWriter#option` method
+ * @param writeParams any extra information that the write method wants to provide
+ */
+@DeveloperApi
+case class OutputParams(
+    datasourceType: String,
+    destination: Option[String],
+    options: Map[String, String],
+    writeParams: Map[String, String] = Map.empty)
+
 /**
  * :: Experimental ::
 *
@@ -98,18 +121,26 @@ class ExecutionListenerManager private[sql] () extends Logging {
     listeners.clear()
   }
 
-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+  private[sql] def onSuccess(
+      funcName: String,
+      qe: QueryExecution,
+      duration: Long,
+      outputParams: Option[OutputParams] = None): Unit = {
     readLock {
       withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
+        listener.onSuccess(funcName, qe, duration, outputParams)
       }
     }
   }
 
-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+  private[sql] def onFailure(
+      funcName: String,
+      qe: QueryExecution,
+      exception: Exception,
+      outputParams: Option[OutputParams] = None): Unit = {
     readLock {
       withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
+        listener.onFailure(funcName, qe, exception, outputParams)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala
new file mode 100644
index 0000000000000..1e823a9840f48
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener} + +/** + * Test cases for the property 'spark.sql.queryExecutionListeners' that adds the + * @see `QueryExecutionListener` to a @see `SparkSession` + */ +class SparkSQLQueryExecutionListenerSuite + extends SparkFunSuite + with MockitoSugar + with BeforeAndAfterEach { + + override def afterEach(): Unit = { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + SparkContext.clearActiveContext() + } + + test("Creation of SparkContext with non-existent QueryExecutionListener class fails fast") { + intercept[ClassNotFoundException] { + SparkSession + .builder() + .master("local") + .config("spark.sql.queryExecutionListeners", "non.existent.QueryExecutionListener") + .getOrCreate() + } + assert(!SparkSession.getDefaultSession.isDefined) + } + + test("QueryExecutionListener that doesn't have a default constructor fails fast") { + intercept[InstantiationException] { + SparkSession + .builder() + .master("local") + .config("spark.sql.queryExecutionListeners", classOf[NoZeroArgConstructorListener].getName) + .getOrCreate() + } + assert(!SparkSession.getDefaultSession.isDefined) + } + + test("Normal QueryExecutionListeners gets added as listeners") { + val sparkSession = SparkSession + .builder() + .master("local") + .config("mykey", "myvalue") + .config("spark.sql.queryExecutionListeners", + classOf[NormalQueryExecutionListener].getName + " ," + + classOf[AnotherQueryExecutionListener].getName) + .getOrCreate() + assert(SparkSession.getDefaultSession.isDefined) + assert(NormalQueryExecutionListener.successCount === 0) + assert(NormalQueryExecutionListener.failureCount === 0) + assert(AnotherQueryExecutionListener.successCount === 0) + assert(AnotherQueryExecutionListener.failureCount === 0) + sparkSession.listenerManager.onSuccess("test1", mock[QueryExecution], 0) + assert(NormalQueryExecutionListener.successCount === 1) + assert(NormalQueryExecutionListener.failureCount === 0) + assert(AnotherQueryExecutionListener.successCount === 1) + assert(AnotherQueryExecutionListener.failureCount === 0) + sparkSession.listenerManager.onFailure("test2", mock[QueryExecution], new Exception) + assert(NormalQueryExecutionListener.successCount === 1) + assert(NormalQueryExecutionListener.failureCount === 1) + assert(AnotherQueryExecutionListener.successCount === 1) + assert(AnotherQueryExecutionListener.failureCount === 1) + } +} + +class NoZeroArgConstructorListener(myString: String) extends QueryExecutionListener { + + override def onSuccess( + funcName: String, + qe: QueryExecution, + durationNs: Long, + options: Option[OutputParams] + ): Unit = {} + + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + options: Option[OutputParams] + ): Unit = {} +} + +class NormalQueryExecutionListener extends QueryExecutionListener { + + override def onSuccess( + funcName: String, + qe: QueryExecution, + durationNs: Long, + options: Option[OutputParams] + ): Unit = { NormalQueryExecutionListener.successCount += 1 } + + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + options: Option[OutputParams] + ): Unit = { NormalQueryExecutionListener.failureCount += 1 } +} + +object NormalQueryExecutionListener { + var successCount = 0; + 
var failureCount = 0; +} + +class AnotherQueryExecutionListener extends QueryExecutionListener { + + override def onSuccess( + funcName: String, + qe: QueryExecution, + durationNs: Long, + options: Option[OutputParams] + ): Unit = { AnotherQueryExecutionListener.successCount += 1 } + + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + options: Option[OutputParams] + ): Unit = { AnotherQueryExecutionListener.failureCount += 1 } +} + +object AnotherQueryExecutionListener { + var successCount = 0; + var failureCount = 0; +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 3ae5ce610d2a6..d6bb793d9b620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.util import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.sql.{functions, QueryTest} +import org.apache.spark.sql.{functions, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext @@ -33,9 +33,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)] val listener = new QueryExecutionListener { // Only test successful case here, so no need to implement `onFailure` - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = {} - override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + override def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long, + outputParams: Option[OutputParams]): Unit = { metrics += ((funcName, qe, duration)) } } @@ -61,12 +69,20 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { test("execute callback functions when a DataFrame action failed") { val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] val listener = new QueryExecutionListener { - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = { metrics += ((funcName, qe, exception)) } // Only test failed case here, so no need to implement `onSuccess` - override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} + override def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long, + outputParams: Option[OutputParams]): Unit = {} } spark.listenerManager.register(listener) @@ -89,9 +105,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { val metrics = ArrayBuffer.empty[Long] val listener = new QueryExecutionListener { // Only test successful case here, so no need to implement `onFailure` - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = {} - override def 
onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + override def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long, + outputParams: Option[OutputParams]): Unit = { val metric = qe.executedPlan match { case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows") case other => other.longMetric("numOutputRows") @@ -114,6 +138,53 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.listenerManager.unregister(listener) } + test("QueryExecutionListener gets called on DataFrameWriter.parquet method") { + callSave("parquet", (df: DataFrame, path: String) => df.write.parquet(path)) + } + + test("QueryExecutionListener gets called on DataFrameWriter.json method") { + callSave("json", (df: DataFrame, path: String) => df.write.json(path)) + } + + test("QueryExecutionListener gets called on DataFrameWriter.csv method") { + callSave("csv", (df: DataFrame, path: String) => df.write.csv(path)) + } + + test("QueryExecutionListener gets called on DataFrameWriter.saveAsTable method") { + var onWriteSuccessCalled = false + spark.listenerManager.register(new QueryExecutionListener { + + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = {} + + override def onSuccess( + funcName: String, + qe: QueryExecution, + durationNs: Long, + outputParams: Option[OutputParams]): Unit = { + assert(durationNs > 0) + assert(qe ne null) + onWriteSuccessCalled = true + } + }) + withTable("bar") { + Seq(1 -> 100).toDF("x", "y").write.saveAsTable("bar") + } + assert(onWriteSuccessCalled) + } + + private def callSave(source: String, callSaveFunction: (DataFrame, String) => Unit): Unit = { + val testQueryExecutionListener = new TestQueryExecutionListener(source) + spark.listenerManager.register(testQueryExecutionListener) + withTempPath { path => + callSaveFunction(Seq(1 -> 100).toDF("x", "y"), path.getAbsolutePath) + } + assert(testQueryExecutionListener.onWriteSuccessCalled) + } + // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never // updated, we can filter it out later. However, when we aggregate(sum) accumulator values at // driver side for SQL physical operators, these -1 values will make our result smaller. 
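The tests above exercise the new write-path callbacks with listeners registered programmatically through `spark.listenerManager.register`. As a hedged sketch of the same pattern outside a test (e.g. pasted into spark-shell; the output path is illustrative only), a registered listener would observe `OutputParams` along these lines:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(
      funcName: String,
      qe: QueryExecution,
      durationNs: Long,
      outputParams: Option[OutputParams]): Unit = {
    // With this patch, df.write.parquet(path) reports funcName = "save",
    // datasourceType = "parquet" and destination = Some(path); saveAsTable and
    // insertInto report the unquoted table identifier, and a JDBC save reports
    // the target table taken from the "dbtable" option.
    outputParams.foreach { p =>
      println(s"$funcName -> ${p.datasourceType}: ${p.destination.getOrElse("<none>")}")
    }
  }

  override def onFailure(
      funcName: String,
      qe: QueryExecution,
      exception: Exception,
      outputParams: Option[OutputParams]): Unit = ()
})

Seq(1 -> 100).toDF("x", "y").write.mode("overwrite").parquet("/tmp/outputparams-demo")
```

For reads and other non-writer actions, `outputParams` remains `None`, matching the default arguments added to `ExecutionListenerManager` above.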
@@ -123,9 +194,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { val metrics = ArrayBuffer.empty[Long] val listener = new QueryExecutionListener { // Only test successful case here, so no need to implement `onFailure` - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = {} - override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + override def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long, + outputParams: Option[OutputParams]): Unit = { metrics += qe.executedPlan.longMetric("dataSize").value val bottomAgg = qe.executedPlan.children(0).children(0) metrics += bottomAgg.longMetric("dataSize").value @@ -159,4 +238,34 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.listenerManager.unregister(listener) } + + class TestQueryExecutionListener(source: String) extends QueryExecutionListener { + var onWriteSuccessCalled = false + + // Only test successful case here, so no need to implement `onFailure` + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception, + outputParams: Option[OutputParams]): Unit = {} + + override def onSuccess( + funcName: String, + qe: QueryExecution, + durationNs: Long, + outputParams: Option[OutputParams]): Unit = { + assert(qe ne null) + assert(outputParams.isDefined) + assert(!outputParams.get.destination.isEmpty) + assert(!outputParams.get.datasourceType.isEmpty) + assert(durationNs > 0) + onWriteSuccessCalled = true + } + } + + protected override def afterEach(): Unit = { + super.afterEach() + spark.listenerManager.clear() + } + }
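A closing compatibility note: because `onSuccess` and `onFailure` gain a fourth parameter, listeners compiled against the previous three-argument trait no longer override the new methods, which is what the MiMa exclusions earlier in this patch acknowledge. A sketch of the change an existing listener author would make (the class name is hypothetical):

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener}

class UpgradedListener extends QueryExecutionListener {
  // Before this patch the overrides were:
  //   def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit
  //   def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit
  // After this patch an Option[OutputParams] must be accepted as well; it is Some(...)
  // only for executions triggered through DataFrameWriter and None otherwise.

  override def onSuccess(
      funcName: String,
      qe: QueryExecution,
      durationNs: Long,
      outputParams: Option[OutputParams]): Unit = {
    println(s"$funcName finished in $durationNs ns; wrote output: ${outputParams.isDefined}")
  }

  override def onFailure(
      funcName: String,
      qe: QueryExecution,
      exception: Exception,
      outputParams: Option[OutputParams]): Unit = {
    println(s"$funcName failed with ${exception.getClass.getSimpleName}")
  }
}
```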