
Commit 0795c16

introduce SQLExecution.ignoreNestedExecutionId
1 parent 6b3d022 commit 0795c16

8 files changed: +77 -76 lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 36 deletions
@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
       _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
     val numRows = _numRows.max(0)
     val takeResult = toDF().take(numRows + 1)
-    showString(takeResult, numRows, truncate, vertical)
-  }
-
-  private def showString(
-      dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
-    val hasMoreData = dataWithOneMoreRow.length > numRows
-    val data = dataWithOneMoreRow.take(numRows)
+    val hasMoreData = takeResult.length > numRows
+    val data = takeResult.take(numRows)

     lazy val timeZone =
       DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
     println(showString(numRows, truncate = 0))
   }

-  // An internal version of `show`, which won't set execution id and trigger listeners.
-  private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
-    val numRows = _numRows.max(0)
-    val takeResult = toDF().takeInternal(numRows + 1)
-
-    if (truncate) {
-      println(showString(takeResult, numRows, truncate = 20, vertical = false))
-    } else {
-      println(showString(takeResult, numRows, truncate = 0, vertical = false))
-    }
-  }
-  // scalastyle:on println
-
   /**
    * Displays the Dataset in a tabular form. For example:
    * {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
    */
   def take(n: Int): Array[T] = head(n)

-  // An internal version of `take`, which won't set execution id and trigger listeners.
-  private[sql] def takeInternal(n: Int): Array[T] = {
-    collectFromPlan(limit(n).queryExecution.executedPlan)
-  }
-
   /**
    * Returns the first `n` rows in the Dataset as a list.
    *
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
    */
   def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)

-  // An internal version of `collect`, which won't set execution id and trigger listeners.
-  private[sql] def collectInternal(): Array[T] = {
-    collectFromPlan(queryExecution.executedPlan)
-  }
-
   /**
    * Returns a Java list that contains all rows in this Dataset.
    *
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
     plan.executeCollect().head.getLong(0)
   }

-  // An internal version of `count`, which won't set execution id and trigger listeners.
-  private[sql] def countInternal(): Long = {
-    groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
-  }
-
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
    *
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
     createTempViewCommand(viewName, replace = true, global = true)
   }

-  private[spark] def createTempViewCommand(
+  private def createTempViewCommand(
       viewName: String,
       replace: Boolean,
       global: Boolean): CreateViewCommand = {

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 23 additions & 1 deletion
@@ -29,6 +29,8 @@ object SQLExecution {

   val EXECUTION_ID_KEY = "spark.sql.execution.id"

+  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
+
   private val _nextExecutionId = new AtomicLong(0)

   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -85,6 +87,9 @@ object SQLExecution {
         sc.setLocalProperty(EXECUTION_ID_KEY, null)
       }
       r
+    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
+      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore this new execution id.
+      body
     } else {
       // Don't support nested `withNewExecutionId`. This is an example of the nested
       // `withNewExecutionId`:
@@ -100,7 +105,9 @@ object SQLExecution {
       // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
       //
       // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
+      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
+        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
+        "jobs issued by the nested execution.")
     }
   }

@@ -118,4 +125,19 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
+
+  /**
+   * Wrap an action which may have nested execution id. This method can be used to run an execution
+   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`.
+   */
+  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
+    try {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
+      body
+    } finally {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
+    }
+  }
 }
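
The new helper is intended for code that already runs inside SQLExecution.withNewExecutionId and needs to trigger a nested Dataset action, as the callers updated below do. A minimal usage sketch, not part of this commit: the helper function name and its table parameter are hypothetical and only illustrate the call pattern.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.execution.SQLExecution

  // Hypothetical helper, assumed to be called while an outer query already
  // holds the spark.sql.execution.id local property (e.g. from a command's run()).
  def rowCountInsideRunningExecution(spark: SparkSession, table: String): Long = {
    SQLExecution.ignoreNestedExecutionId(spark) {
      // Without the wrapper, this nested action would throw
      // "spark.sql.execution.id is already set ...".
      spark.table(table).count()
    }
  }

Inside the wrapped block the nested action does not get its own execution id, so its jobs are not tracked as a separate query by the SQL listener or UI.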

sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.internal.SessionState


@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
       // 2. when total size is changed, `oldRowCount` becomes invalid.
       // This is to make sure that we only record the right statistics.
       if (!noscan) {
-        val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
+        val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+          sparkSession.table(tableIdentWithDB).count()
+        }
         if (newRowCount >= 0 && newRowCount != oldRowCount) {
           newStats = if (newStats.isDefined) {
             newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))

sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala

Lines changed: 10 additions & 9 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution

 case class CacheTableCommand(
     tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

   override def run(sparkSession: SparkSession): Seq[Row] = {
-    plan.foreach { logicalPlan =>
-      Dataset.ofRows(sparkSession, logicalPlan)
-        .createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
-        .run(sparkSession)
-    }
-    sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      plan.foreach { logicalPlan =>
+        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+      }
+      sparkSession.catalog.cacheTable(tableIdent.quotedString)

-    if (!isLazy) {
-      // Performs eager caching
-      sparkSession.table(tableIdent).countInternal()
+      if (!isLazy) {
+        // Performs eager caching
+        sparkSession.table(tableIdent).count()
+      }
     }

     Seq.empty[Row]

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

Lines changed: 4 additions & 2 deletions
@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine =
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
+    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    }
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 6 additions & 8 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
   }

   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    import scala.collection.JavaConverters._
-
-    val options = jdbcOptions.asProperties.asScala +
-      ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
-    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
-
-    new JdbcRelationProvider().createRelation(
-      data.sparkSession.sqlContext, mode, options.toMap, data)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.write
+        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
+    }
   }

   override def toString: String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala

Lines changed: 9 additions & 4 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.types.StructType

 class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    data.sparkSession.createDataFrame(
-      data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
-      .showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.sparkSession.createDataFrame(
+        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+        .show(numRowsToShow, isTruncated)
+    }
   }
 }

@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider

     // Truncate the displayed data if it is too long, by default it is true
     val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    data.showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
+      data.show(numRowsToShow, isTruncated)
+    }

     ConsoleRelation(sqlContext, data)
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 18 additions & 15 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      outputMode match {
-        case Append | Update =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized { batches += rows }
-
-        case Complete =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized {
-            batches.clear()
-            batches += rows
-          }
-
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
+      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+        outputMode match {
+          case Append | Update =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized { batches += rows }
+
+          case Complete =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized {
+              batches.clear()
+              batches += rows
+            }
+
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Output mode $outputMode is not supported by MemorySink")
+        }
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")
