diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 9cf480caba3e4..1a44cf20075dc 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1300,10 +1300,28 @@ Configuration of in-memory caching can be done using the `setConf` method on `Sp
+## QueryExecutionListener Options
+
+The following option can be used to attach query execution listeners to a `SparkSession` when it
+is created.
+
+| Property Name | Default | Meaning |
+| ------------- | ------- | ------- |
+| `spark.sql.queryExecutionListeners` | (none) | A comma-separated list of classes that implement `QueryExecutionListener`. When a `SparkSession` is created, instances of these listeners are added to it. These classes need to have a zero-argument constructor. If a specified class cannot be found, or it does not have a valid constructor, the `SparkSession` creation fails with an exception. |
+
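+As an illustration, a minimal listener might look like the following sketch (the package and
+class names are examples only, not part of Spark):
+
+```scala
+package com.example.listeners
+
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener}
+
+class LoggingQueryExecutionListener extends QueryExecutionListener {
+
+  override def onSuccess(
+      funcName: String,
+      qe: QueryExecution,
+      durationNs: Long,
+      outputParams: Option[OutputParams]): Unit = {
+    // outputParams is defined only when the query was triggered by a write operation.
+    val destination = outputParams.flatMap(_.destination).getOrElse("n/a")
+    println(s"$funcName succeeded in ${durationNs / 1000000} ms, destination: $destination")
+  }
+
+  override def onFailure(
+      funcName: String,
+      qe: QueryExecution,
+      exception: Exception,
+      outputParams: Option[OutputParams]): Unit = {
+    println(s"$funcName failed: ${exception.getMessage}")
+  }
+}
+```
+
+The listener is then attached when the session is built, for example:
+
+```scala
+import org.apache.spark.sql.SparkSession
+
+val spark = SparkSession
+  .builder()
+  .config("spark.sql.queryExecutionListeners",
+    "com.example.listeners.LoggingQueryExecutionListener")
+  .getOrCreate()
+```
+
+Listeners can also be registered programmatically on an existing session through
+`spark.listenerManager.register(...)`.
+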
## Other Configuration Options
The following options can also be used to tune the performance of query execution. It is possible
-that these options will be deprecated in future release as more optimizations are performed automatically.
+that these options will be deprecated in a future release as more optimizations are performed
+automatically.
| Property Name | Default | Meaning |
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143523387..a5609f2d2489c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -133,7 +133,13 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.startOffset"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.endOffset"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query"),
+
+ // [SPARK-18120][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onSuccess"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onSuccess"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.QueryExecutionListener.onFailure")
)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ff1f0177e8ba0..9e8afcff6c097 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,10 +26,13 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.OutputParams
/**
* Interface used to write a [[Dataset]] to external storage systems (e.g. file systems,
@@ -189,6 +192,32 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
this
}
+ /**
+ * Wraps a DataFrameWriter action to track the query execution and time cost, then reports them
+ * to the user-registered callback functions.
+ *
+ * @param funcName an identifier for the method executing the query
+ * @param qe the `QueryExecution` object associated with the query
+ * @param outputParams the output parameters useful for query analysis
+ * @param action the function that executes the query, after which the listener methods get
+ * called
+ */
+ private def withAction(
+ funcName: String,
+ qe: QueryExecution,
+ outputParams: OutputParams)(action: => Unit): Unit = {
+ try {
+ val start = System.nanoTime()
+ action
+ val end = System.nanoTime()
+ df.sparkSession.listenerManager.onSuccess(funcName, qe, end - start, Some(outputParams))
+ } catch {
+ case e: Exception =>
+ df.sparkSession.listenerManager.onFailure(funcName, qe, e, Some(outputParams))
+ throw e
+ }
+ }
+
/**
* Saves the content of the `DataFrame` at the specified path.
*
@@ -218,7 +247,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
bucketSpec = getBucketSpec,
options = extraOptions.toMap)
- dataSource.write(mode, df)
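+ // For the JDBC source, the destination is the target table name; for other sources it is
+ // the output path, if one was specified.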
+ val destination = source match {
+ case "jdbc" => extraOptions.get(JDBCOptions.JDBC_TABLE_NAME)
+ case _ => extraOptions.get("path")
+ }
+ val outputParams = OutputParams(source, destination, extraOptions.toMap)
+ withAction("save", df.queryExecution, outputParams) {
+ dataSource.write(mode, df)
+ }
}
/**
@@ -261,13 +297,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
)
}
- df.sparkSession.sessionState.executePlan(
+ val qe = df.sparkSession.sessionState.executePlan(
InsertIntoTable(
table = UnresolvedRelation(tableIdent),
partition = Map.empty[String, Option[String]],
child = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
- ifNotExists = false)).toRdd
+ ifNotExists = false))
+ val outputParams = OutputParams(source, Some(tableIdent.unquotedString), extraOptions.toMap)
+ withAction("insertInto", qe, outputParams)(qe.toRdd)
}
private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
@@ -324,7 +362,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def assertNotPartitioned(operation: String): Unit = {
if (partitioningColumns.isDefined) {
- throw new AnalysisException( s"'$operation' does not support partitioning")
+ throw new AnalysisException(s"'$operation' does not support partitioning")
}
}
@@ -428,8 +466,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec
)
- df.sparkSession.sessionState.executePlan(
- CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+ val qe = df.sparkSession.sessionState.executePlan(
+ CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+ val outputParams = OutputParams(source, Some(tableIdent.unquotedString), extraOptions.toMap)
+ withAction("saveAsTable", qe, outputParams)(qe.toRdd)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index f3dde480eabe0..8a027c56cdf5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -40,12 +40,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
+import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, LongType, StructType}
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener}
import org.apache.spark.util.Utils
@@ -876,6 +876,9 @@ object SparkSession {
}
session = new SparkSession(sparkContext)
options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
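+ // Register any listeners named in spark.sql.queryExecutionListeners so that they
+ // receive callbacks for queries run through this session.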
+ for (qeListener <- createQueryExecutionListeners(session.sparkContext.getConf)) {
+ session.listenerManager.register(qeListener)
+ }
defaultSession.set(session)
// Register a successfully instantiated context to the singleton. This should be at the
@@ -893,6 +896,12 @@ object SparkSession {
}
}
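+ /**
+ * Instantiates the QueryExecutionListener classes named in spark.sql.queryExecutionListeners,
+ * each of which must have a zero-argument constructor. A class that cannot be found or
+ * instantiated causes SparkSession creation to fail with the underlying exception.
+ */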
+ private def createQueryExecutionListeners(conf: SparkConf): Seq[QueryExecutionListener] = {
+ conf.get(StaticSQLConf.QUERY_EXECUTION_LISTENERS)
+ .map(Utils.classForName(_))
+ .map(_.newInstance().asInstanceOf[QueryExecutionListener])
+ }
+
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5ba4192512a59..b8b9a2e03f638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1047,4 +1047,14 @@ object StaticSQLConf {
"SQL configuration and the current database.")
.booleanConf
.createWithDefault(false)
+
+ val QUERY_EXECUTION_LISTENERS = buildConf("spark.sql.queryExecutionListeners")
+ .doc("A comma-separated list of classes that implement QueryExecutionListener. When creating " +
+ "a SparkSession, instances of these listeners will be added to it. These classes need " +
+ "to have a zero-argument constructor. If a specified class cannot be found, or it does " +
+ "not have a valid constructor, the SparkSession creation will fail with an exception.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 26ad0eadd9d4c..2f0b39d8a7990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -44,12 +44,15 @@ trait QueryExecutionListener {
* @param qe the QueryExecution object that carries detail information like logical plan,
* physical plan, etc.
* @param durationNs the execution time for this query in nanoseconds.
- *
- * @note This can be invoked by multiple different threads.
+ * @param outputParams the output parameters when the method is invoked as a result of a
+ * write operation; `None` in the case of a read
+ *
+ * @note This can be invoked by multiple different threads.
*/
@DeveloperApi
- def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit
-
+ def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ outputParams: Option[OutputParams]): Unit
+
/**
* A callback function that will be called when a query execution failed.
*
@@ -57,14 +60,34 @@ trait QueryExecutionListener {
* @param qe the QueryExecution object that carries detail information like logical plan,
* physical plan, etc.
* @param exception the exception that failed this query.
+ * @param outputParams the output parameters when the method is invoked as a result of a
+ * write operation; `None` in the case of a read
*
* @note This can be invoked by multiple different threads.
*/
@DeveloperApi
- def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit
+ def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit
}
-
+
+/**
+ * Contains extra information useful for query analysis, passed on from the methods in
+ * `org.apache.spark.sql.DataFrameWriter` while writing to a data source.
+ *
+ * @param datasourceType the type of data source written to, e.g. csv, parquet, json, hive, jdbc
+ * @param destination the path or table name written to
+ * @param options the map containing the output options for the underlying data source,
+ * specified by using the `org.apache.spark.sql.DataFrameWriter#option` method
+ * @param writeParams any extra information that the write method wants to provide
+ */
+@DeveloperApi
+case class OutputParams(
+ datasourceType: String,
+ destination: Option[String],
+ options: Map[String, String],
+ writeParams: Map[String, String] = Map.empty)
+
/**
* :: Experimental ::
*
@@ -98,18 +121,26 @@ class ExecutionListenerManager private[sql] () extends Logging {
listeners.clear()
}
- private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ private[sql] def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long,
+ outputParams: Option[OutputParams] = None): Unit = {
readLock {
withErrorHandling { listener =>
- listener.onSuccess(funcName, qe, duration)
+ listener.onSuccess(funcName, qe, duration, outputParams)
}
}
}
- private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ private[sql] def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams] = None): Unit = {
readLock {
withErrorHandling { listener =>
- listener.onFailure(funcName, qe, exception)
+ listener.onFailure(funcName, qe, exception, outputParams)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala
new file mode 100644
index 0000000000000..1e823a9840f48
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSQLQueryExecutionListenerSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.{OutputParams, QueryExecutionListener}
+
+/**
+ * Test cases for the 'spark.sql.queryExecutionListeners' property, which adds instances of
+ * `QueryExecutionListener` to a `SparkSession` when it is created.
+ */
+class SparkSQLQueryExecutionListenerSuite
+ extends SparkFunSuite
+ with MockitoSugar
+ with BeforeAndAfterEach {
+
+ override def afterEach(): Unit = {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ SparkContext.clearActiveContext()
+ }
+
+ test("Creation of SparkSession with non-existent QueryExecutionListener class fails fast") {
+ intercept[ClassNotFoundException] {
+ SparkSession
+ .builder()
+ .master("local")
+ .config("spark.sql.queryExecutionListeners", "non.existent.QueryExecutionListener")
+ .getOrCreate()
+ }
+ assert(SparkSession.getDefaultSession.isEmpty)
+ }
+
+ test("QueryExecutionListener that doesn't have a default constructor fails fast") {
+ intercept[InstantiationException] {
+ SparkSession
+ .builder()
+ .master("local")
+ .config("spark.sql.queryExecutionListeners", classOf[NoZeroArgConstructorListener].getName)
+ .getOrCreate()
+ }
+ assert(SparkSession.getDefaultSession.isEmpty)
+ }
+
+ test("Normal QueryExecutionListeners get added as listeners") {
+ val sparkSession = SparkSession
+ .builder()
+ .master("local")
+ .config("mykey", "myvalue")
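+ // The stray whitespace around the separator below is intended to check that class names
+ // are trimmed when the comma-separated list is parsed.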
+ .config("spark.sql.queryExecutionListeners",
+ classOf[NormalQueryExecutionListener].getName + " ,"
+ + classOf[AnotherQueryExecutionListener].getName)
+ .getOrCreate()
+ assert(SparkSession.getDefaultSession.isDefined)
+ assert(NormalQueryExecutionListener.successCount === 0)
+ assert(NormalQueryExecutionListener.failureCount === 0)
+ assert(AnotherQueryExecutionListener.successCount === 0)
+ assert(AnotherQueryExecutionListener.failureCount === 0)
+ sparkSession.listenerManager.onSuccess("test1", mock[QueryExecution], 0)
+ assert(NormalQueryExecutionListener.successCount === 1)
+ assert(NormalQueryExecutionListener.failureCount === 0)
+ assert(AnotherQueryExecutionListener.successCount === 1)
+ assert(AnotherQueryExecutionListener.failureCount === 0)
+ sparkSession.listenerManager.onFailure("test2", mock[QueryExecution], new Exception)
+ assert(NormalQueryExecutionListener.successCount === 1)
+ assert(NormalQueryExecutionListener.failureCount === 1)
+ assert(AnotherQueryExecutionListener.successCount === 1)
+ assert(AnotherQueryExecutionListener.failureCount === 1)
+ }
+}
+
+class NoZeroArgConstructorListener(myString: String) extends QueryExecutionListener {
+
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ options: Option[OutputParams]
+ ): Unit = {}
+
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ options: Option[OutputParams]
+ ): Unit = {}
+}
+
+class NormalQueryExecutionListener extends QueryExecutionListener {
+
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ options: Option[OutputParams]
+ ): Unit = { NormalQueryExecutionListener.successCount += 1 }
+
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ options: Option[OutputParams]
+ ): Unit = { NormalQueryExecutionListener.failureCount += 1 }
+}
+
+object NormalQueryExecutionListener {
+ var successCount = 0
+ var failureCount = 0
+}
+
+class AnotherQueryExecutionListener extends QueryExecutionListener {
+
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ options: Option[OutputParams]
+ ): Unit = { AnotherQueryExecutionListener.successCount += 1 }
+
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ options: Option[OutputParams]
+ ): Unit = { AnotherQueryExecutionListener.failureCount += 1 }
+}
+
+object AnotherQueryExecutionListener {
+ var successCount = 0
+ var failureCount = 0
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 3ae5ce610d2a6..d6bb793d9b620 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.util
import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
+import org.apache.spark.sql.{functions, DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
@@ -33,9 +33,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
val listener = new QueryExecutionListener {
// Only test successful case here, so no need to implement `onFailure`
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {}
- override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long,
+ outputParams: Option[OutputParams]): Unit = {
metrics += ((funcName, qe, duration))
}
}
@@ -61,12 +69,20 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
test("execute callback functions when a DataFrame action failed") {
val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)]
val listener = new QueryExecutionListener {
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {
metrics += ((funcName, qe, exception))
}
// Only test failed case here, so no need to implement `onSuccess`
- override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long,
+ outputParams: Option[OutputParams]): Unit = {}
}
spark.listenerManager.register(listener)
@@ -89,9 +105,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
val metrics = ArrayBuffer.empty[Long]
val listener = new QueryExecutionListener {
// Only test successful case here, so no need to implement `onFailure`
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {}
- override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long,
+ outputParams: Option[OutputParams]): Unit = {
val metric = qe.executedPlan match {
case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
@@ -114,6 +138,53 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
spark.listenerManager.unregister(listener)
}
+ test("QueryExecutionListener gets called on DataFrameWriter.parquet method") {
+ callSave("parquet", (df: DataFrame, path: String) => df.write.parquet(path))
+ }
+
+ test("QueryExecutionListener gets called on DataFrameWriter.json method") {
+ callSave("json", (df: DataFrame, path: String) => df.write.json(path))
+ }
+
+ test("QueryExecutionListener gets called on DataFrameWriter.csv method") {
+ callSave("csv", (df: DataFrame, path: String) => df.write.csv(path))
+ }
+
+ test("QueryExecutionListener gets called on DataFrameWriter.saveAsTable method") {
+ var onWriteSuccessCalled = false
+ spark.listenerManager.register(new QueryExecutionListener {
+
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {}
+
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ outputParams: Option[OutputParams]): Unit = {
+ assert(durationNs > 0)
+ assert(qe ne null)
+ onWriteSuccessCalled = true
+ }
+ })
+ withTable("bar") {
+ Seq(1 -> 100).toDF("x", "y").write.saveAsTable("bar")
+ }
+ assert(onWriteSuccessCalled)
+ }
+
+ private def callSave(source: String, callSaveFunction: (DataFrame, String) => Unit): Unit = {
+ val testQueryExecutionListener = new TestQueryExecutionListener(source)
+ spark.listenerManager.register(testQueryExecutionListener)
+ withTempPath { path =>
+ callSaveFunction(Seq(1 -> 100).toDF("x", "y"), path.getAbsolutePath)
+ }
+ assert(testQueryExecutionListener.onWriteSuccessCalled)
+ }
+
// TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never
// updated, we can filter it out later. However, when we aggregate(sum) accumulator values at
// driver side for SQL physical operators, these -1 values will make our result smaller.
@@ -123,9 +194,17 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
val metrics = ArrayBuffer.empty[Long]
val listener = new QueryExecutionListener {
// Only test successful case here, so no need to implement `onFailure`
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {}
- override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long,
+ outputParams: Option[OutputParams]): Unit = {
metrics += qe.executedPlan.longMetric("dataSize").value
val bottomAgg = qe.executedPlan.children(0).children(0)
metrics += bottomAgg.longMetric("dataSize").value
@@ -159,4 +238,34 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
spark.listenerManager.unregister(listener)
}
+
+ class TestQueryExecutionListener(source: String) extends QueryExecutionListener {
+ var onWriteSuccessCalled = false
+
+ // Only test successful case here, so no need to implement `onFailure`
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception,
+ outputParams: Option[OutputParams]): Unit = {}
+
+ override def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ durationNs: Long,
+ outputParams: Option[OutputParams]): Unit = {
+ assert(qe ne null)
+ assert(outputParams.isDefined)
+ assert(outputParams.get.destination.isDefined)
+ assert(outputParams.get.datasourceType.nonEmpty)
+ assert(durationNs > 0)
+ onWriteSuccessCalled = true
+ }
+ }
+
+ protected override def afterEach(): Unit = {
+ super.afterEach()
+ spark.listenerManager.clear()
+ }
+
}