@@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
SparkListenerSQLExecutionStart}
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}

object SQLExecution {

val EXECUTION_ID_KEY = "spark.sql.execution.id"

private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"

private val _nextExecutionId = new AtomicLong(0)

private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -45,10 +42,8 @@ object SQLExecution {

private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
val sc = sparkSession.sparkContext
val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
// only throw an exception during tests. a missing execution ID should not fail a job.
if (testing && !isNestedExecution && !hasExecutionId) {
if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
// Attention testers: when a test fails with this exception, it means that the action that
// started execution of a query didn't call withNewExecutionId. The execution ID should be
// set by calling withNewExecutionId in the action that begins execution, like
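For context, a minimal sketch of the pattern this comment asks for; the object and method names are illustrative rather than part of this PR, and it assumes code living in the org.apache.spark.sql package because withNewExecutionId is private[sql]:

package org.apache.spark.sql

import org.apache.spark.sql.execution.SQLExecution

// Illustrative sketch only: the action that begins execution wraps its body in
// withNewExecutionId, so spark.sql.execution.id is set (and the SQL execution
// start/end events are posted) before any job launched by `work` runs.
private[sql] object ExecutionIdExample {
  def runAction[T, R](ds: Dataset[T])(work: => R): R =
    SQLExecution.withNewExecutionId(ds.sparkSession, ds.queryExecution) {
      work
    }
}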
@@ -66,56 +61,27 @@ object SQLExecution {
queryExecution: QueryExecution)(body: => T): T = {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
if (oldExecutionId == null) {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
val callSite = sparkSession.sparkContext.getCallSite()

sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
try {
body
} finally {
sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
}
} finally {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, null)
}
} else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
// If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
// `body`, so that Spark jobs issued in the `body` won't be tracked.
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
val callSite = sparkSession.sparkContext.getCallSite()

sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
try {
sc.setLocalProperty(EXECUTION_ID_KEY, null)
body
} finally {
sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
}
} else {
// Don't support nested `withNewExecutionId`. This is an example of the nested
// `withNewExecutionId`:
//
// class DataFrame {
// def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
// }
//
// Note: `collect` will call withNewExecutionId
// In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
// for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
// for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
// all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
//
// A real case is the `DataFrame.count` method.
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
"action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
"jobs issued by the nested execution.")
} finally {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
}
}

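Reading the added lines of this hunk together, the simplified withNewExecutionId should end up looking roughly like the following; this is a reconstruction from the diff for readability, not the verbatim result:

def withNewExecutionId[T](
    sparkSession: SparkSession,
    queryExecution: QueryExecution)(body: => T): T = {
  val sc = sparkSession.sparkContext
  // Remember any execution ID set by an enclosing query; nested executions are now
  // allowed, so we simply restore it when this one finishes.
  val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
  val executionId = SQLExecution.nextExecutionId
  sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
  executionIdToQueryExecution.put(executionId, queryExecution)
  try {
    // getCallSite() first picks up any call site set earlier; calling Utils.getCallSite()
    // directly on streaming queries would yield "run at <unknown>:0".
    val callSite = sparkSession.sparkContext.getCallSite()
    sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
      executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
      SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
    try {
      body
    } finally {
      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
        executionId, System.currentTimeMillis()))
    }
  } finally {
    executionIdToQueryExecution.remove(executionId)
    sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
  }
}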
@@ -133,20 +99,4 @@
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}

/**
* Wrap an action which may have nested execution id. This method can be used to run an execution
* inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
* all Spark jobs issued in the body won't be tracked in UI.
*/
def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
val sc = sparkSession.sparkContext
val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
try {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
body
} finally {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
}
}
}
@@ -59,9 +59,7 @@ case class AnalyzeTableCommand(
// 2. when total size is changed, `oldRowCount` becomes invalid.
// This is to make sure that we only record the right statistics.
if (!noscan) {
val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
sparkSession.table(tableIdentWithDB).count()
}
val newRowCount = sparkSession.table(tableIdentWithDB).count()
if (newRowCount >= 0 && newRowCount != oldRowCount) {
newStats = if (newStats.isDefined) {
newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
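A hedged usage sketch of the path touched here (table name and session are illustrative): without NOSCAN, AnalyzeTableCommand runs a count() job to refresh the row count, which after this change is just a nested execution rather than one wrapped in ignoreNestedExecutionId.

// Assumes a SparkSession named `spark` and an existing table `events` (illustrative).
spark.sql("ANALYZE TABLE events COMPUTE STATISTICS")     // triggers the count() job
spark.sql("DESCRIBE TABLE EXTENDED events").show(false)  // table details, incl. statistics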
@@ -34,16 +34,14 @@ case class CacheTableCommand(
override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

override def run(sparkSession: SparkSession): Seq[Row] = {
SQLExecution.ignoreNestedExecutionId(sparkSession) {
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)

if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).count()
}
if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).count()
}

Seq.empty[Row]
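Similarly, a hedged usage sketch for the cache path (view names are illustrative): a non-lazy CACHE TABLE still performs the eager count() shown above, it just runs as a nested execution now.

// Assumes a SparkSession named `spark`; names are illustrative.
spark.range(0, 1000).toDF("id").createOrReplaceTempView("src")
spark.sql("CACHE TABLE cached_src AS SELECT * FROM src")       // eager: count() runs here
spark.sql("CACHE LAZY TABLE lazy_src AS SELECT * FROM src")    // materialized on first use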
@@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
}
val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
}

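A hedged usage sketch of what drives this code (the file path is illustrative): reading CSV with a header and schema inference is what makes filterCommentAndEmpty(...).take(1) peek at the first line, and both that peek and the type inference now run as ordinary nested executions.

// Assumes a SparkSession named `spark`; the path is illustrative.
val people = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("/tmp/people.csv")
people.printSchema()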
@@ -130,11 +130,9 @@ private[sql] case class JDBCRelation(
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
}
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
}

override def toString: String = {
@@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
println(batchIdStr)
println("-------------------------------------------")
// scalastyle:off println
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
.show(numRowsToShow, isTruncated)
}
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
.show(numRowsToShow, isTruncated)
}
}

@@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider

// Truncate the displayed data if it is too long, by default it is true
val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
data.show(numRowsToShow, isTruncated)
}
data.show(numRowsToShow, isTruncated)

ConsoleRelation(sqlContext, data)
}
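A hedged usage sketch for the console sink (the rate source and option values are illustrative, and the option names are assumed): each micro-batch is collected back to the driver and re-shown, and those collect/show jobs are now plain nested executions.

// Assumes a SparkSession named `spark`.
val consoleQuery = spark.readStream
  .format("rate")
  .load()
  .writeStream
  .format("console")
  .option("numRows", "5")        // assumed to map to numRowsToShow above
  .option("truncate", "false")   // maps to isTruncated above
  .start()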
@@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collect())
synchronized { batches += rows }

case Complete =>
val rows = AddedData(batchId, data.collect())
synchronized {
batches.clear()
batches += rows
}

case _ =>
throw new IllegalArgumentException(
s"Output mode $outputMode is not supported by MemorySink")
}
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collect())
synchronized { batches += rows }

case Complete =>
val rows = AddedData(batchId, data.collect())
synchronized {
batches.clear()
batches += rows
}

case _ =>
throw new IllegalArgumentException(
s"Output mode $outputMode is not supported by MemorySink")
}
} else {
logDebug(s"Skipping already committed batch: $batchId")
@@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession
class SQLExecutionSuite extends SparkFunSuite {

test("concurrent query execution (SPARK-10548)") {
// Try to reproduce the issue with the old SparkContext
Contributor Author commented: now we allow nested execution, so we can't reproduce this bug anymore.

val conf = new SparkConf()
.setMaster("local[*]")
.setAppName("test")
val badSparkContext = new BadSparkContext(conf)
try {
testConcurrentQueryExecution(badSparkContext)
fail("unable to reproduce SPARK-10548")
} catch {
case e: IllegalArgumentException =>
assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
} finally {
badSparkContext.stop()
}

// Verify that the issue is fixed with the latest SparkContext
val goodSparkContext = new SparkContext(conf)
try {
testConcurrentQueryExecution(goodSparkContext)
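To tie the author's note above to code, a minimal sketch of the nesting that is now tolerated; it assumes a SparkSession named `spark` and code in the org.apache.spark.sql package (as in this suite), since withNewExecutionId is private[sql]:

import org.apache.spark.sql.execution.SQLExecution

// Illustrative only: before this change the inner count() would fail with
// "spark.sql.execution.id is already set"; now it simply becomes a nested execution.
val df = spark.range(0, 10).toDF("id")
SQLExecution.withNewExecutionId(spark, df.queryExecution) {
  df.count()  // count() calls withNewExecutionId again -> nested execution
}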
@@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite {
}
}

/**
* A bad [[SparkContext]] that does not clone the inheritable thread local properties
* when passing them to children threads.
*/
private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
override protected def initialValue(): Properties = new Properties()
}
}

object SQLExecutionSuite {
@volatile var canProgress = false
}