@@ -86,12 +86,10 @@ private[kafka010] object KafkaWriter extends Logging {
topic: Option[String] = None): Unit = {
val schema = queryExecution.analyzed.output
validateQuery(queryExecution, kafkaParameters, topic)
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
queryExecution.toRdd.foreachPartition { iter =>
val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
finallyBlock = writeTask.close())
}
queryExecution.toRdd.foreachPartition { iter =>
val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
finallyBlock = writeTask.close())
}
}
}
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
@@ -607,7 +608,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
try {
val start = System.nanoTime()
// call `QueryExecution.toRDD` to trigger the execution of commands.
qe.toRdd
SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
val end = System.nanoTime()
session.listenerManager.onSuccess(name, qe, end - start)
} catch {
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (4 changes: 2 additions & 2 deletions)
@@ -180,9 +180,9 @@ class Dataset[T] private[sql](
// to happen right away to let these side effects take place eagerly.
queryExecution.analyzed match {
case c: Command =>
LocalRelation(c.output, queryExecution.executedPlan.executeCollect())

Review thread:

cloud-fan (Contributor), Apr 8, 2017:
How about LocalRelation(c.output, withAction("collect", queryExecution)(_.executeCollect()))?

cloud-fan (Contributor):
Actually, do we need to do this? Most Commands are just local operations (talking to the metastore).

Contributor (author):
Yeah, the check I added to ensure we get the same results in the SQL tab produces several hundred failures that go through this path. It looks like the entry point is almost always spark.sql when the SQL statement is a command like CTAS.

I like your version and will update.
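
For reference, withAction is the private Dataset helper the suggestion relies on: it wraps an action in a fresh SQL execution ID and reports timing to the session's ExecutionListenerManager. A simplified sketch (details approximate, not the exact source for this Spark version):

```scala
// Simplified sketch of Dataset.withAction; it lives inside Dataset, so `sparkSession`
// is the Dataset's own session. Names and details are approximate.
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U): U = {
  try {
    val start = System.nanoTime()
    // Every Spark job launched by `action` is grouped under one SQL execution ID.
    val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
      action(qe.executedPlan)
    }
    sparkSession.listenerManager.onSuccess(name, qe, System.nanoTime() - start)
    result
  } catch {
    case e: Exception =>
      sparkSession.listenerManager.onFailure(name, qe, e)
      throw e
  }
}
```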

LocalRelation(c.output, withAction("collect", queryExecution)(_.executeCollect()))
case u @ Union(children) if children.forall(_.isInstanceOf[Command]) =>
LocalRelation(u.output, queryExecution.executedPlan.executeCollect())
LocalRelation(u.output, withAction("collect", queryExecution)(_.executeCollect()))
case _ =>
queryExecution.analyzed
}
@@ -120,18 +120,24 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
case ExecutedCommandExec(desc: DescribeTableCommand) =>
// If it is a describe command for a Hive table, we want to have the output format
// be similar with Hive.
desc.run(sparkSession).map {
SQLExecution.withNewExecutionId(sparkSession, this) {
desc.run(sparkSession)
}.map {
case Row(name: String, dataType: String, comment) =>
Seq(name, dataType,
Option(comment.asInstanceOf[String]).getOrElse(""))
.map(s => String.format(s"%-20s", s))
.mkString("\t")
.map(s => String.format(s"%-20s", s))
.mkString("\t")
}
// SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp.
case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
command.executeCollect().map(_.getString(1))
SQLExecution.withNewExecutionId(sparkSession, this) {
command.executeCollect()
}.map(_.getString(1))
case other =>
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
val result: Seq[Seq[Any]] = SQLExecution.withNewExecutionId(sparkSession, this) {
other.executeCollectPublic()
}.map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
// Reformat to match hive tab delimited output.
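
As a small illustration of the describe formatting above (hypothetical input row, not from the patch):

```scala
// Each describe field is left-justified to 20 characters, then the fields are tab-joined.
val fields = Seq("id", "bigint", "")  // name, data type, empty comment
val line = fields.map(s => String.format(s"%-20s", s)).mkString("\t")
// line is "id" padded to width 20, a tab, "bigint" padded to width 20, a tab, 20 spaces.
```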
@@ -21,11 +21,11 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
SparkListenerSQLExecutionStart}
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}

object SQLExecution {
object SQLExecution extends Logging {

val EXECUTION_ID_KEY = "spark.sql.execution.id"

@@ -39,6 +39,32 @@ object SQLExecution {
executionIdToQueryExecution.get(executionId)
}

private val testing = sys.props.contains("spark.testing")

private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
Review thread:

Contributor:
This is only called in FileFormatWriter; are there any other places we need to consider?

Contributor (author):
To keep this PR from growing too big, I want to use it only where I've removed withNewExecutionId, to check for regressions. I'll follow up with another PR that adds more checks.

// only throw an exception during tests. a missing execution ID should not fail a job.
if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
// Attention testers: when a test fails with this exception, it means that the action that
// started execution of a query didn't call withNewExecutionId. The execution ID should be
// set by calling withNewExecutionId in the action that begins execution, like
// Dataset.collect or DataFrameWriter.insertInto.
throw new IllegalStateException("Execution ID should be set")
}
}
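
A sketch of the intended call pattern (hypothetical caller; the real entry points are actions such as Dataset.collect or DataFrameWriter.insertInto, and `sparkSession`/`queryExecution` are assumed to be in scope):

```scala
// The action that begins execution sets the ID; code further down the write path
// (e.g. FileFormatWriter.write) only verifies that it is present.
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
  // ... later, inside the write path:
  SQLExecution.checkSQLExecutionId(sparkSession)  // throws only when spark.testing is set
  // run the Spark jobs that actually write the data
}
```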

private val ALLOW_NESTED_EXECUTION = "spark.sql.execution.nested"

private[sql] def nested[T](sparkSession: SparkSession)(body: => T): T = {
val sc = sparkSession.sparkContext
val allowNestedPreviousValue = sc.getLocalProperty(SQLExecution.ALLOW_NESTED_EXECUTION)
try {
sc.setLocalProperty(SQLExecution.ALLOW_NESTED_EXECUTION, "true")
body
} finally {
sc.setLocalProperty(SQLExecution.ALLOW_NESTED_EXECUTION, allowNestedPreviousValue)
}
}
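
A minimal usage sketch (hypothetical table name): a command that legitimately runs a Dataset action inside an already-tracked execution wraps it in nested to suppress the nested-execution warning handled further down:

```scala
// Inside a command's run() method, where an execution ID was already set by the caller:
val rowCount = SQLExecution.nested(sparkSession) {
  // This count() starts its own query execution; nested(...) marks it as intentional.
  sparkSession.table("some_table").count()
}
```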

/**
* Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
* we can connect them with an execution.
@@ -73,21 +99,35 @@
}
r
} else {
// Don't support nested `withNewExecutionId`. This is an example of the nested
// `withNewExecutionId`:
// Nesting `withNewExecutionId` may be incorrect; log a warning.
//
// This is an example of the nested `withNewExecutionId`:
//
// class DataFrame {
// // Note: `collect` will call withNewExecutionId
// def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
// }
//
// Note: `collect` will call withNewExecutionId
// In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
// for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
// for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
// for the outer Dataset won't be executed. So it's meaningless to create a new Execution
// for the outer Dataset. Even if we track it, since its "executedPlan" doesn't run,
// all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
//
// A real case is the `DataFrame.count` method.
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
// Some operations will start nested executions. For example, CacheTableCommand uses
// Dataset#count to materialize cached records when caching is not lazy. Because there are
// legitimate reasons to nest executions in withNewExecutionId, this logs a warning rather
// than throwing an exception, so it does not fail at runtime. In tests an exception is
// thrown, to ensure that nesting is avoided.
//
// To avoid this warning, use nested { ... }
if (!Option(sc.getLocalProperty(ALLOW_NESTED_EXECUTION)).exists(_.toBoolean)) {
if (testing) {
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set: $oldExecutionId")
} else {
logWarning(s"$EXECUTION_ID_KEY is already set")
}
}
body
}
}

@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableTyp
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SQLExecution


/**
@@ -96,7 +97,9 @@ case class AnalyzeColumnCommand(
attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr))

val namedExpressions = expressions.map(e => Alias(e, e.toString)())
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
val statsRow = SQLExecution.nested(sparkSession) {
Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
}

val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SessionState


@@ -56,7 +57,9 @@ case class AnalyzeTableCommand(
// 2. when total size is changed, `oldRowCount` becomes invalid.
// This is to make sure that we only record the right statistics.
if (!noscan) {
val newRowCount = sparkSession.table(tableIdentWithDB).count()
val newRowCount = SQLExecution.nested(sparkSession) {
sparkSession.table(tableIdentWithDB).count()
}
if (newRowCount >= 0 && newRowCount != oldRowCount) {
newStats = if (newStats.isDefined) {
newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution

case class CacheTableCommand(
tableIdent: TableIdentifier,
@@ -36,13 +37,17 @@ case class CacheTableCommand(

override def run(sparkSession: SparkSession): Seq[Row] = {
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
SQLExecution.nested(sparkSession) {
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
}
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)

if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).count()
SQLExecution.nested(sparkSession) {
sparkSession.table(tableIdent).count()
}
}

Seq.empty[Row]
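
For example, a non-lazy CACHE TABLE ... AS SELECT statement (hypothetical table names) goes through both nested blocks above, once for creating the temp view and once for the eager count:

```scala
// Hypothetical SQL that exercises CacheTableCommand's eager-caching path.
spark.sql("CACHE TABLE cached_sales AS SELECT * FROM sales WHERE year = 2017")
// Not LAZY, so run() materializes the cache via count(), wrapped in SQLExecution.nested.
```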
@@ -161,50 +161,51 @@ object FileFormatWriter extends Logging {
}
}

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val rdd = if (orderingMatched) {
queryExecution.toRdd
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = queryExecution.executedPlan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
sparkStageId = taskContext.stageId(),
sparkPartitionId = taskContext.partitionId(),
sparkAttemptNumber = taskContext.attemptNumber(),
committer,
iterator = iter)
},
0 until rdd.partitions.length,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
})

val commitMsgs = ret.map(_.commitMsg)
val updatedPartitions = ret.flatMap(_.updatedPartitions)
.distinct.map(PartitioningUtils.parsePathFragment)

committer.commitJob(job, commitMsgs)
logInfo(s"Job ${job.getJobID} committed.")
refreshFunction(updatedPartitions)
} catch { case cause: Throwable =>
logError(s"Aborting job ${job.getJobID}.", cause)
committer.abortJob(job)
throw new SparkException("Job aborted.", cause)
// During tests, make sure there is an execution ID.
SQLExecution.checkSQLExecutionId(sparkSession)
Review thread:

zsxwing (Member):
The major issue is this change. For all queries using FileFormatWriter, we won't get any metrics, because of queryExecution = Dataset.ofRows(sparkSession, query).queryExecution. It creates a new QueryExecution and we don't track it.

zsxwing (Member):
To make SQL metrics work, we should always wrap the correct QueryExecution with SparkListenerSQLExecutionStart.

Contributor (author):
@zsxwing, that and similar cases are what I was talking about earlier when I said there are two physical plans. The inner Dataset.ofRows ends up creating a completely separate plan.

Are you saying that adding SparkListenerSQLExecutionStart (and SparkListenerSQLExecutionEnd) events will fix the metrics problem? I think it would at least require the metrics work-around I added to SQLListener, since metrics are filtered out if they aren't reported by the physical plan.
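
For context on that point: the start/end events carry the plan of whichever QueryExecution is passed to withNewExecutionId, which is why wrapping the wrong one loses the metrics of the plan that actually runs. The core of withNewExecutionId, roughly (a simplified sketch of the Spark 2.x internals; names and field order approximate):

```scala
// Simplified sketch, not the exact source. `sc`, `callSite`, `oldExecutionId`, and `body`
// are locals of withNewExecutionId(sparkSession, queryExecution)(body).
val executionId = SQLExecution.nextExecutionId   // hypothetical accessor for this sketch
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString)
sc.listenerBus.post(SparkListenerSQLExecutionStart(
  executionId, callSite.shortForm, callSite.longForm,
  queryExecution.toString,                                   // physical plan description
  SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),  // source of the SQL metrics
  System.currentTimeMillis()))
try {
  body
} finally {
  sc.listenerBus.post(SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()))
  sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
```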


// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val rdd = if (orderingMatched) {
queryExecution.toRdd
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = queryExecution.executedPlan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
sparkStageId = taskContext.stageId(),
sparkPartitionId = taskContext.partitionId(),
sparkAttemptNumber = taskContext.attemptNumber(),
committer,
iterator = iter)
},
0 until rdd.partitions.length,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
})

val commitMsgs = ret.map(_.commitMsg)
val updatedPartitions = ret.flatMap(_.updatedPartitions)
.distinct.map(PartitioningUtils.parsePathFragment)

committer.commitJob(job, commitMsgs)
logInfo(s"Job ${job.getJobID} committed.")
refreshFunction(updatedPartitions)
} catch { case cause: Throwable =>
logError(s"Aborting job ${job.getJobID}.", cause)
committer.abortJob(job)
throw new SparkException("Job aborted.", cause)
}
}

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

@@ -37,14 +38,18 @@ case class InsertIntoDataSourceCommand(

override def run(sparkSession: SparkSession): Seq[Row] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
val data = Dataset.ofRows(sparkSession, query)
// Apply the schema of the existing table to the new data.
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)

// Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
// data source relation.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
SQLExecution.nested(sparkSession) {
val data = Dataset.ofRows(sparkSession, query)

// Apply the schema of the existing table to the new data.
val df = sparkSession.internalCreateDataFrame(
data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)

// Re-cache all cached plans(including this relation itself, if it's cached) that refer to
// this data source relation.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
}

Seq.empty[Row]
}
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.RunnableCommand

/**
@@ -41,12 +42,13 @@ case class SaveIntoDataSourceCommand(
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
DataSource(
sparkSession,
className = provider,
partitionColumns = partitionColumns,
options = options).write(mode, Dataset.ofRows(sparkSession, query))

SQLExecution.nested(sparkSession) {
DataSource(
sparkSession,
className = provider,
partitionColumns = partitionColumns,
options = options).write(mode, Dataset.ofRows(sparkSession, query))
}
Seq.empty[Row]
}
}