
Commit 15ff85b

cloud-fan authored and davies committed
[SPARK-11068] [SQL] add callback to query execution
With this feature, we can track the query plan, time cost, and exceptions during query execution for Spark users.

Author: Wenchen Fan <[email protected]>

Closes #9078 from cloud-fan/callback.
1 parent e170c22 commit 15ff85b
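
As a quick illustration of the feature described in the commit message, here is a minimal sketch of how an application could register a listener. The listener body, the println output, and the sample DataFrame are illustrative only; the APIs used (QueryExecutionListener, ExecutionListenerManager, and sqlContext.listenerManager) are the ones added by this commit, assuming an existing SQLContext named sqlContext.

import org.apache.spark.sql.QueryExecutionListener
import org.apache.spark.sql.execution.QueryExecution

// Hypothetical listener: report what each DataFrame action did and how long it took.
val listener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
    // `duration` is in nanoseconds per the listener contract.
    println(s"$funcName succeeded in ${duration / 1000000} ms")
    println(qe.analyzed)
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    println(s"$funcName failed: ${exception.getMessage}")
  }
}

sqlContext.listenerManager.register(listener)

// Actions routed through the new withCallback helper now report back to the listener.
import sqlContext.implicits._
val df = Seq(1 -> "a").toDF("i", "j")
df.select("i").collect()    // listener receives funcName = "collect"
df.filter($"i" > 0).count() // listener receives funcName = "count"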

4 files changed (+261, -6 lines)


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 40 additions & 6 deletions
@@ -1344,7 +1344,9 @@ class DataFrame private[sql](
   * @group action
   * @since 1.3.0
   */
-  def head(n: Int): Array[Row] = limit(n).collect()
+  def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+    df.collect(needCallback = false)
+  }

  /**
   * Returns the first row.

@@ -1414,25 +1416,39 @@
   * @group action
   * @since 1.3.0
   */
-  def collect(): Array[Row] = withNewExecutionId {
-    queryExecution.executedPlan.executeCollectPublic()
+  def collect(): Array[Row] = collect(needCallback = true)
+
+  private def collect(needCallback: Boolean): Array[Row] = {
+    def execute(): Array[Row] = withNewExecutionId {
+      queryExecution.executedPlan.executeCollectPublic()
+    }
+
+    if (needCallback) {
+      withCallback("collect", this)(_ => execute())
+    } else {
+      execute()
+    }
  }

  /**
   * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
   * @group action
   * @since 1.3.0
   */
-  def collectAsList(): java.util.List[Row] = withNewExecutionId {
-    java.util.Arrays.asList(rdd.collect() : _*)
+  def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+    withNewExecutionId {
+      java.util.Arrays.asList(rdd.collect() : _*)
+    }
  }

  /**
   * Returns the number of rows in the [[DataFrame]].
   * @group action
   * @since 1.3.0
   */
-  def count(): Long = groupBy().count().collect().head.getLong(0)
+  def count(): Long = withCallback("count", groupBy().count()) { df =>
+    df.collect(needCallback = false).head.getLong(0)
+  }

  /**
   * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.

@@ -1936,6 +1952,24 @@
    SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
  }

+  /**
+   * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the
+   * user-registered callback functions.
+   */
+  private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
+    try {
+      val start = System.nanoTime()
+      val result = action(df)
+      val end = System.nanoTime()
+      sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start)
+      result
+    } catch {
+      case e: Exception =>
+        sqlContext.listenerManager.onFailure(name, df.queryExecution, e)
+        throw e
+    }
+  }
+
  ////////////////////////////////////////////////////////////////////////////
  ////////////////////////////////////////////////////////////////////////////
  // End of deprecated methods
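
A note on what callers observe with the changes above: withCallback wraps head, collect, collectAsList, and count, while the internal collect that head and count perform passes needCallback = false, so it does not fire a second event. A small hedged sketch of the resulting behavior (assuming a listener is registered as in the test suite at the end of this commit and sqlContext.implicits._ is in scope):

// Illustrative only: each action fires exactly one callback, named after the action.
val df = Seq(1 -> "a").toDF("i", "j")
df.collect() // listener sees onSuccess("collect", qe, durationNanos)
df.count()   // listener sees onSuccess("count", qe, durationNanos); the nested
             // collect(needCallback = false) inside count() does not fire again
df.head(1)   // listener sees onSuccess("head", qe, durationNanos)
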
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable.ListBuffer

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.Logging
import org.apache.spark.sql.execution.QueryExecution


/**
 * The interface of query execution listener that can be used to analyze execution metrics.
 *
 * Note that implementations should guarantee thread-safety as they will be used in a non
 * thread-safe way.
 */
@Experimental
trait QueryExecutionListener {

  /**
   * A callback function that will be called when a query executed successfully.
   * Implementations should guarantee thread-safe.
   *
   * @param funcName the name of the action that triggered this query.
   * @param qe the QueryExecution object that carries detail information like logical plan,
   *           physical plan, etc.
   * @param duration the execution time for this query in nanoseconds.
   */
  @DeveloperApi
  def onSuccess(funcName: String, qe: QueryExecution, duration: Long)

  /**
   * A callback function that will be called when a query execution failed.
   * Implementations should guarantee thread-safe.
   *
   * @param funcName the name of the action that triggered this query.
   * @param qe the QueryExecution object that carries detail information like logical plan,
   *           physical plan, etc.
   * @param exception the exception that failed this query.
   */
  @DeveloperApi
  def onFailure(funcName: String, qe: QueryExecution, exception: Exception)
}

@Experimental
class ExecutionListenerManager extends Logging {
  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
  private[this] val lock = new ReentrantReadWriteLock()

  /** Acquires a read lock on the cache for the duration of `f`. */
  private def readLock[A](f: => A): A = {
    val rl = lock.readLock()
    rl.lock()
    try f finally {
      rl.unlock()
    }
  }

  /** Acquires a write lock on the cache for the duration of `f`. */
  private def writeLock[A](f: => A): A = {
    val wl = lock.writeLock()
    wl.lock()
    try f finally {
      wl.unlock()
    }
  }

  /**
   * Registers the specified QueryExecutionListener.
   */
  @DeveloperApi
  def register(listener: QueryExecutionListener): Unit = writeLock {
    listeners += listener
  }

  /**
   * Unregisters the specified QueryExecutionListener.
   */
  @DeveloperApi
  def unregister(listener: QueryExecutionListener): Unit = writeLock {
    listeners -= listener
  }

  /**
   * Clears out all registered QueryExecutionListeners.
   */
  @DeveloperApi
  def clear(): Unit = writeLock {
    listeners.clear()
  }

  private[sql] def onSuccess(
      funcName: String,
      qe: QueryExecution,
      duration: Long): Unit = readLock {
    withErrorHandling { listener =>
      listener.onSuccess(funcName, qe, duration)
    }
  }

  private[sql] def onFailure(
      funcName: String,
      qe: QueryExecution,
      exception: Exception): Unit = readLock {
    withErrorHandling { listener =>
      listener.onFailure(funcName, qe, exception)
    }
  }

  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
    for (listener <- listeners) {
      try {
        f(listener)
      } catch {
        case e: Exception => logWarning("error executing query execution listener", e)
      }
    }
  }
}
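
The scaladoc above requires listener implementations to be thread-safe: ExecutionListenerManager only takes a read lock before dispatching, so several queries can invoke the same listener concurrently. As a hedged sketch of one way to meet that contract (the TimingListener class and its totalNanos helper are hypothetical, not part of this commit), per-action durations can be aggregated in concurrent data structures:

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.sql.QueryExecutionListener
import org.apache.spark.sql.execution.QueryExecution

// Hypothetical listener: accumulate total successful execution time per action name.
class TimingListener extends QueryExecutionListener {
  private val totals = new ConcurrentHashMap[String, AtomicLong]()

  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
    // putIfAbsent + addAndGet keep this safe when several queries finish at once.
    totals.putIfAbsent(funcName, new AtomicLong(0L))
    totals.get(funcName).addAndGet(duration)
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}

  /** Total time (in nanoseconds) spent in successful runs of the given action. */
  def totalNanos(funcName: String): Long = {
    val counter = totals.get(funcName)
    if (counter == null) 0L else counter.get()
  }
}

Such a listener would be registered the same way as in the test suite below, e.g. sqlContext.listenerManager.register(new TimingListener).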

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 0 deletions
@@ -177,6 +177,9 @@ class SQLContext private[sql](
   */
  def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

+  @transient
+  lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+
  @transient
  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

import scala.collection.mutable.ArrayBuffer

class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
  import testImplicits._
  import functions._

  test("execute callback functions when a DataFrame action finished successfully") {
    val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
    val listener = new QueryExecutionListener {
      // Only test successful case here, so no need to implement `onFailure`
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}

      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        metrics += ((funcName, qe, duration))
      }
    }
    sqlContext.listenerManager.register(listener)

    val df = Seq(1 -> "a").toDF("i", "j")
    df.select("i").collect()
    df.filter($"i" > 0).count()

    assert(metrics.length == 2)

    assert(metrics(0)._1 == "collect")
    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
    assert(metrics(0)._3 > 0)

    assert(metrics(1)._1 == "count")
    assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
    assert(metrics(1)._3 > 0)
  }

  test("execute callback functions when a DataFrame action failed") {
    val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)]
    val listener = new QueryExecutionListener {
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
        metrics += ((funcName, qe, exception))
      }

      // Only test failed case here, so no need to implement `onSuccess`
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
    }
    sqlContext.listenerManager.register(listener)

    val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") }
    val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j")

    // Ignore the log when we are expecting an exception.
    sparkContext.setLogLevel("FATAL")
    val e = intercept[SparkException](df.select(errorUdf($"i")).collect())

    assert(metrics.length == 1)
    assert(metrics(0)._1 == "collect")
    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
    assert(metrics(0)._3.getMessage == e.getMessage)
  }
}

0 commit comments
