Commit 9f6b3e6
[SPARK-21238][SQL] allow nested SQL execution
## What changes were proposed in this pull request?

This is another follow-up for #18064. In #18064, we wrap every SQL command with a SQL execution, which makes nested SQL execution very likely to happen. #18419 tried to improve this a little by introducing `SQLExecution.ignoreNestedExecutionId`. However, that approach is not friendly to data source developers, who may need to update their code to use the `ignoreNestedExecutionId` API.

This PR proposes a new solution: simply allow nested execution. The downside is that we may end up with multiple executions for one query. We can improve this by updating the data organization in SQLListener to have a 1-n mapping from query to execution instead of a 1-1 mapping; that can be done in a follow-up.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #18450 from cloud-fan/execution-id.
1 parent a946be3 commit 9f6b3e6
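
To make the intent concrete, below is a minimal, hypothetical sketch of the kind of nesting this commit allows (the session setup and table name are illustrative, not taken from this commit): a SQL command that is itself wrapped in a SQL execution triggers another Dataset action, which now simply starts its own execution instead of requiring `SQLExecution.ignoreNestedExecutionId`.

```scala
// Illustrative only: CACHE TABLE runs as one SQL execution, and its eager
// Dataset.count() (see CacheTableCommand in this diff) starts a nested one.
import org.apache.spark.sql.SparkSession

object NestedExecutionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-execution-sketch")
      .getOrCreate()

    spark.range(10).createOrReplaceTempView("t")

    // Before this patch the inner count() had to be wrapped in
    // ignoreNestedExecutionId (or it would throw); after it, both the command
    // and its eager count() are tracked as separate executions.
    spark.sql("CACHE TABLE t")

    spark.stop()
  }
}
```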

8 files changed (+50, -138 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 19 additions & 69 deletions
@@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
-  SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 
 object SQLExecution {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
-  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
-
   private val _nextExecutionId = new AtomicLong(0)
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -45,10 +42,8 @@ object SQLExecution {
 
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
     val sc = sparkSession.sparkContext
-    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
-    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && !isNestedExecution && !hasExecutionId) {
+    if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -66,56 +61,27 @@ object SQLExecution {
       queryExecution: QueryExecution)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
-    if (oldExecutionId == null) {
-      val executionId = SQLExecution.nextExecutionId
-      sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
-      executionIdToQueryExecution.put(executionId, queryExecution)
-      try {
-        // sparkContext.getCallSite() would first try to pick up any call site that was previously
-        // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
-        // streaming queries would give us call site like "run at <unknown>:0"
-        val callSite = sparkSession.sparkContext.getCallSite()
-
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-        try {
-          body
-        } finally {
-          sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
-        }
-      } finally {
-        executionIdToQueryExecution.remove(executionId)
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
-      }
-    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
-      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
-      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+    val executionId = SQLExecution.nextExecutionId
+    sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+    executionIdToQueryExecution.put(executionId, queryExecution)
+    try {
+      // sparkContext.getCallSite() would first try to pick up any call site that was previously
+      // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
+      // streaming queries would give us call site like "run at <unknown>:0"
+      val callSite = sparkSession.sparkContext.getCallSite()
+
+      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
       try {
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
         body
       } finally {
-        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+          executionId, System.currentTimeMillis()))
       }
-    } else {
-      // Don't support nested `withNewExecutionId`. This is an example of the nested
-      // `withNewExecutionId`:
-      //
-      // class DataFrame {
-      //   def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
-      // }
-      //
-      // Note: `collect` will call withNewExecutionId
-      // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
-      // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
-      // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
-      // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
-      //
-      // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
-        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
-        "jobs issued by the nested execution.")
+    } finally {
+      executionIdToQueryExecution.remove(executionId)
+      sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
     }
   }
 
@@ -133,20 +99,4 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
-
-  /**
-   * Wrap an action which may have nested execution id. This method can be used to run an execution
-   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
-   * all Spark jobs issued in the body won't be tracked in UI.
-   */
-  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
-    val sc = sparkSession.sparkContext
-    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
-    try {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
-      body
-    } finally {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
-    }
-  }
 }
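
The heart of this change is that `withNewExecutionId` no longer branches on whether an execution ID is already set: it always allocates a new ID and, on the way out, restores whatever value was there before (possibly null) instead of clearing the property. A minimal, standalone sketch of that save/restore pattern, with hypothetical names and none of the listener or `executionIdToQueryExecution` bookkeeping:

```scala
// Sketch of the save/restore pattern used by the rewritten withNewExecutionId.
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext

object ExecutionIdScope {
  val EXECUTION_ID_KEY = "spark.sql.execution.id"
  private val nextId = new AtomicLong(0)

  // Install a fresh execution ID for `body`, then restore whatever was set before
  // (which may be null). Restoring the previous value, rather than clearing it,
  // is what lets an outer execution keep its ID after a nested one finishes.
  def withNewId[T](sc: SparkContext)(body: => T): T = {
    val oldId = sc.getLocalProperty(EXECUTION_ID_KEY)
    sc.setLocalProperty(EXECUTION_ID_KEY, nextId.getAndIncrement().toString)
    try {
      body
    } finally {
      sc.setLocalProperty(EXECUTION_ID_KEY, oldId)
    }
  }
}
```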

sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala

Lines changed: 1 addition & 3 deletions
@@ -51,9 +51,7 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-        sparkSession.table(tableIdentWithDB).count()
-      }
+      val newRowCount = sparkSession.table(tableIdentWithDB).count()
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))

sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala

Lines changed: 7 additions & 9 deletions
@@ -34,16 +34,14 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      plan.foreach { logicalPlan =>
-        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
-      }
-      sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    plan.foreach { logicalPlan =>
+      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+    }
+    sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
-      if (!isLazy) {
-        // Performs eager caching
-        sparkSession.table(tableIdent).count()
-      }
+    if (!isLazy) {
+      // Performs eager caching
+      sparkSession.table(tableIdent).count()
     }
 
     Seq.empty[Row]

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

Lines changed: 1 addition & 3 deletions
@@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 3 additions & 5 deletions
@@ -130,11 +130,9 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.write
-        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
-    }
+    data.write
+      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+      .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
   }
 
   override def toString: String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala

Lines changed: 4 additions & 8 deletions
@@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.sparkSession.createDataFrame(
-        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
-        .show(numRowsToShow, isTruncated)
-    }
+    data.sparkSession.createDataFrame(
+      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+      .show(numRowsToShow, isTruncated)
   }
 }
 
@@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider
 
     // Truncate the displayed data if it is too long, by default it is true
    val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
-      data.show(numRowsToShow, isTruncated)
-    }
+    data.show(numRowsToShow, isTruncated)
 
     ConsoleRelation(sqlContext, data)
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 15 additions & 17 deletions
@@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-        outputMode match {
-          case Append | Update =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized { batches += rows }
-
-          case Complete =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized {
-              batches.clear()
-              batches += rows
-            }
-
-          case _ =>
-            throw new IllegalArgumentException(
-              s"Output mode $outputMode is not supported by MemorySink")
-        }
+      outputMode match {
+        case Append | Update =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized { batches += rows }
+
+        case Complete =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized {
+            batches.clear()
+            batches += rows
+          }
+
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Output mode $outputMode is not supported by MemorySink")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala

Lines changed: 0 additions & 24 deletions
@@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession
 class SQLExecutionSuite extends SparkFunSuite {
 
   test("concurrent query execution (SPARK-10548)") {
-    // Try to reproduce the issue with the old SparkContext
     val conf = new SparkConf()
       .setMaster("local[*]")
      .setAppName("test")
-    val badSparkContext = new BadSparkContext(conf)
-    try {
-      testConcurrentQueryExecution(badSparkContext)
-      fail("unable to reproduce SPARK-10548")
-    } catch {
-      case e: IllegalArgumentException =>
-        assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
-    } finally {
-      badSparkContext.stop()
-    }
-
-    // Verify that the issue is fixed with the latest SparkContext
     val goodSparkContext = new SparkContext(conf)
     try {
       testConcurrentQueryExecution(goodSparkContext)
@@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite {
   }
 }
 
-/**
- * A bad [[SparkContext]] that does not clone the inheritable thread local properties
- * when passing them to children threads.
- */
-private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
-  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
-    override protected def childValue(parent: Properties): Properties = new Properties(parent)
-    override protected def initialValue(): Properties = new Properties()
-  }
-}
-
 object SQLExecutionSuite {
   @volatile var canProgress = false
 }
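
With the nested-execution exception gone, the removed `BadSparkContext` path can no longer detect SPARK-10548 that way. What can be observed instead is the downside the commit message calls out: one query may now produce several executions. Below is a hedged, illustrative sketch of counting `SparkListenerSQLExecutionStart` events (not part of this commit; event delivery is asynchronous, so a real test would wait for the listener bus to drain rather than sleep):

```scala
// Illustrative only: count how many SQL executions a single CACHE TABLE produces.
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

object CountExecutionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("count-executions-sketch")
      .getOrCreate()
    val started = new AtomicInteger(0)

    spark.sparkContext.addSparkListener(new SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case _: SparkListenerSQLExecutionStart => started.incrementAndGet()
        case _ => ()
      }
    })

    spark.range(10).createOrReplaceTempView("t")
    // The command and its eager count() should each start a tracked execution.
    spark.sql("CACHE TABLE t")

    Thread.sleep(2000) // crude wait for asynchronous listener delivery
    println(s"SQL executions started: ${started.get()}") // expected to be 2 or more
    spark.stop()
  }
}
```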
