39 changes: 3 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
_numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
val numRows = _numRows.max(0)
val takeResult = toDF().take(numRows + 1)
showString(takeResult, numRows, truncate, vertical)
}

private def showString(
dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
val hasMoreData = dataWithOneMoreRow.length > numRows
val data = dataWithOneMoreRow.take(numRows)
val hasMoreData = takeResult.length > numRows
val data = takeResult.take(numRows)

lazy val timeZone =
DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
println(showString(numRows, truncate = 0))
}

// An internal version of `show`, which won't set execution id and trigger listeners.
private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
val numRows = _numRows.max(0)
val takeResult = toDF().takeInternal(numRows + 1)

if (truncate) {
println(showString(takeResult, numRows, truncate = 20, vertical = false))
} else {
println(showString(takeResult, numRows, truncate = 0, vertical = false))
}
}
// scalastyle:on println

/**
* Displays the Dataset in a tabular form. For example:
* {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)

// An internal version of `take`, which won't set execution id and trigger listeners.
private[sql] def takeInternal(n: Int): Array[T] = {
collectFromPlan(limit(n).queryExecution.executedPlan)
}

/**
* Returns the first `n` rows in the Dataset as a list.
*
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
*/
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)

// An internal version of `collect`, which won't set execution id and trigger listeners.
private[sql] def collectInternal(): Array[T] = {
collectFromPlan(queryExecution.executedPlan)
}

/**
* Returns a Java list that contains all rows in this Dataset.
*
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
plan.executeCollect().head.getLong(0)
}

// An internal version of `count`, which won't set execution id and trigger listeners.
private[sql] def countInternal(): Long = {
groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
}

/**
* Returns a new Dataset that has exactly `numPartitions` partitions.
*
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
createTempViewCommand(viewName, replace = true, global = true)
}

private[spark] def createTempViewCommand(
private def createTempViewCommand(
viewName: String,
replace: Boolean,
global: Boolean): CreateViewCommand = {
@@ -29,6 +29,8 @@ object SQLExecution {

val EXECUTION_ID_KEY = "spark.sql.execution.id"

private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"

private val _nextExecutionId = new AtomicLong(0)

private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -42,8 +44,11 @@ object SQLExecution {
private val testing = sys.props.contains("spark.testing")

private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
val sc = sparkSession.sparkContext
val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
// only throw an exception during tests. a missing execution ID should not fail a job.
if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
if (testing && !isNestedExecution && !hasExecutionId) {
// Attention testers: when a test fails with this exception, it means that the action that
// started execution of a query didn't call withNewExecutionId. The execution ID should be
// set by calling withNewExecutionId in the action that begins execution, like
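A minimal sketch of the pattern this comment asks for, not part of the diff: the action that starts a query wraps its body in withNewExecutionId so the execution id is set before any Spark job is submitted. The two-argument withNewExecutionId(sparkSession, queryExecution) overload and access from Spark's own sql packages are assumptions here; the helper name is illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}

object ExecutionIdExample {
  // Hypothetical action wrapper: with the execution id set as a local property,
  // every Spark job issued by the body below is tracked against this query.
  def countWithExecutionId(spark: SparkSession, qe: QueryExecution): Long = {
    SQLExecution.withNewExecutionId(spark, qe) {
      // Same collect-and-read-a-long shape as the count() code shown in this PR.
      qe.executedPlan.executeCollect().head.getLong(0)
    }
  }
}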
@@ -65,7 +70,7 @@ object SQLExecution {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
val r = try {
try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
@@ -84,7 +89,15 @@ object SQLExecution {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, null)
}
r
} else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
// If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
// `body`, so that Spark jobs issued in the `body` won't be tracked.
try {
sc.setLocalProperty(EXECUTION_ID_KEY, null)
Contributor Author: @viirya now we won't track the Spark jobs even in SparkListener.

Member: Looks good.

body
} finally {
sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
}
} else {
// Don't support nested `withNewExecutionId`. This is an example of the nested
// `withNewExecutionId`:
@@ -100,7 +113,9 @@ object SQLExecution {
// all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
//
// A real case is the `DataFrame.count` method.
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
  "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
  "jobs issued by the nested execution.")

Contributor: Nested execution is a developer problem, not a user problem. That's why the original PR did not throw IllegalArgumentException outside of testing. I think that should still be how this is handled.

If this is thrown at runtime, adding the text about ignoreNestedExecutionId is confusing for users, who can't (or shouldn't) set it. A comment is more appropriate if users will see this message. If the change to only throw during testing is added, then I think it is fine to add the text to the exception.

Contributor Author: SQLExecution is kind of a developer API; people who develop data sources may need to call ignoreNestedExecutionId inside their data source implementation, since reading or writing a data source runs inside a command and they may hit the nested execution problem. What do you think?

Contributor: The problem is that this is an easy error to hit, and it shouldn't affect end users. It is better to warn that something is wrong than to fail a job that would otherwise succeed because of a bug in Spark. As for the error message, I think it is fine if we intend to leave it in. I'd just rather not fail user jobs here.

I assume that DataSource developers will have tests, but probably not ones that know to set spark.testing. Is there a better way to detect test cases?
}
}
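The author noted earlier in this file's discussion that the nested jobs won't be tracked even in a SparkListener; the reason is that a job issued under ignoreNestedExecutionId carries no spark.sql.execution.id in its job properties, so anything keyed on that property never associates it with a SQL execution. A minimal probe, not part of this PR, assuming an active SparkContext named sc:

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

class ExecutionIdProbe extends SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    // Jobs from an ignored nested execution (and plain RDD jobs) have no
    // "spark.sql.execution.id" entry in their job properties.
    val executionId = Option(jobStart.properties)
      .flatMap(p => Option(p.getProperty("spark.sql.execution.id")))
    if (executionId.isEmpty) {
      println(s"Job ${jobStart.jobId} is not tied to any SQL execution")
    }
  }
}

// Registration: sc.addSparkListener(new ExecutionIdProbe)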

@@ -118,4 +133,20 @@ object SQLExecution {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}

/**
* Wrap an action that may have a nested execution id. This method can be used to run an execution
* inside another execution, e.g., `CacheTableCommand` needs to call `Dataset.collect`. Note that
* all Spark jobs issued in the body won't be tracked in the UI.
*/
def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
val sc = sparkSession.sparkContext
val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
try {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
body
} finally {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
}
}
}
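The Scaladoc above describes the intended use of ignoreNestedExecutionId; a minimal usage sketch follows, assuming SQLExecution is reachable from the calling code (the review discussion treats it as a developer API). The helper object and method names are illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object NestedActionExample {
  // Hypothetical stand-in for the removed countInternal: a command that is
  // already running under an execution id runs a nested Dataset action here,
  // and the jobs it issues are simply left untracked.
  def rowCount(spark: SparkSession, table: String): Long = {
    SQLExecution.ignoreNestedExecutionId(spark) {
      spark.table(table).count()
    }
  }
}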
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SessionState


@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
// 2. when total size is changed, `oldRowCount` becomes invalid.
// This is to make sure that we only record the right statistics.
if (!noscan) {
val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
sparkSession.table(tableIdentWithDB).count()
}
if (newRowCount >= 0 && newRowCount != oldRowCount) {
newStats = if (newStats.isDefined) {
newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution

case class CacheTableCommand(
tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

override def run(sparkSession: SparkSession): Seq[Row] = {
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan)
.createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
.run(sparkSession)
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)
SQLExecution.ignoreNestedExecutionId(sparkSession) {
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)

if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).countInternal()
if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).count()
}
}

Seq.empty[Row]
@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
val maybeFirstLine =
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
}
inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
}

@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
import scala.collection.JavaConverters._

val options = jdbcOptions.asProperties.asScala +
("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append

new JdbcRelationProvider().createRelation(
data.sparkSession.sqlContext, mode, options.toMap, data)
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
}
}

override def toString: String = {
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.types.StructType

class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
println(batchIdStr)
println("-------------------------------------------")
// scalastyle:off println
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
.showInternal(numRowsToShow, isTruncated)
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
.show(numRowsToShow, isTruncated)
}
}
}

@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider

// Truncate the displayed data if it is too long, by default it is true
val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
data.showInternal(numRowsToShow, isTruncated)
SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
data.show(numRowsToShow, isTruncated)
}

ConsoleRelation(sqlContext, data)
}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collectInternal())
synchronized { batches += rows }

case Complete =>
val rows = AddedData(batchId, data.collectInternal())
synchronized {
batches.clear()
batches += rows
}

case _ =>
throw new IllegalArgumentException(
s"Output mode $outputMode is not supported by MemorySink")
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collect())
synchronized { batches += rows }

case Complete =>
val rows = AddedData(batchId, data.collect())
synchronized {
batches.clear()
batches += rows
}

case _ =>
throw new IllegalArgumentException(
s"Output mode $outputMode is not supported by MemorySink")
}
}
} else {
logDebug(s"Skipping already committed batch: $batchId")