From 0795c16b4beaf70430e8dc62f135f99ac801960e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 26 Jun 2017 12:36:59 +0800 Subject: [PATCH 1/2] introduce SQLExecution.ignoreNestedExecutionId --- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++----------------- .../spark/sql/execution/SQLExecution.scala | 24 +++++++++++- .../command/AnalyzeTableCommand.scala | 5 ++- .../spark/sql/execution/command/cache.scala | 19 ++++----- .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/jdbc/JDBCRelation.scala | 14 +++---- .../sql/execution/streaming/console.scala | 13 +++++-- .../sql/execution/streaming/memory.scala | 33 +++++++++------- 8 files changed, 77 insertions(+), 76 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e66e92091ff..268a37ff5d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -246,13 +246,8 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - showString(takeResult, numRows, truncate, vertical) - } - - private def showString( - dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { - val hasMoreData = dataWithOneMoreRow.length > numRows - val data = dataWithOneMoreRow.take(numRows) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -688,19 +683,6 @@ class Dataset[T] private[sql]( println(showString(numRows, truncate = 0)) } - // An internal version of `show`, which won't set execution id and trigger listeners. - private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { - val numRows = _numRows.max(0) - val takeResult = toDF().takeInternal(numRows + 1) - - if (truncate) { - println(showString(takeResult, numRows, truncate = 20, vertical = false)) - } else { - println(showString(takeResult, numRows, truncate = 0, vertical = false)) - } - } - // scalastyle:on println - /** * Displays the Dataset in a tabular form. For example: * {{{ @@ -2467,11 +2449,6 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) - // An internal version of `take`, which won't set execution id and trigger listeners. - private[sql] def takeInternal(n: Int): Array[T] = { - collectFromPlan(limit(n).queryExecution.executedPlan) - } - /** * Returns the first `n` rows in the Dataset as a list. * @@ -2496,11 +2473,6 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) - // An internal version of `collect`, which won't set execution id and trigger listeners. - private[sql] def collectInternal(): Array[T] = { - collectFromPlan(queryExecution.executedPlan) - } - /** * Returns a Java list that contains all rows in this Dataset. * @@ -2542,11 +2514,6 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } - // An internal version of `count`, which won't set execution id and trigger listeners. - private[sql] def countInternal(): Long = { - groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) - } - /** * Returns a new Dataset that has exactly `numPartitions` partitions. 
* @@ -2792,7 +2759,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private[spark] def createTempViewCommand( + private def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index bb206e84325f..699f3b06be3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -29,6 +29,8 @@ object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" + private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" + private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -85,6 +87,9 @@ object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, null) } r + } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { + // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore this new execution id. + body } else { // Don't support nested `withNewExecutionId`. This is an example of the nested // `withNewExecutionId`: @@ -100,7 +105,9 @@ object SQLExecution { // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. // // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") + throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + + "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + + "jobs issued by the nested execution.") } } @@ -118,4 +125,19 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } + + /** + * Wrap an action which may have nested execution id. This method can be used to run an execution + * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. + */ + def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) + try { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") + body + } finally { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 06e588f56f1e..13b8faff844c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.internal.SessionState @@ -58,7 +59,9 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. 
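// Usage sketch (an editorial aside, not part of this patch): with the API introduced
// above, a command that already runs under a tracked SQL execution wraps its nested
// Dataset action in SQLExecution.ignoreNestedExecutionId instead of calling a
// dedicated *Internal variant. `EagerRowCount` and its method are hypothetical names
// chosen purely for illustration.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object EagerRowCount {
  def count(sparkSession: SparkSession, tableName: String): Long = {
    // The outer command execution has already set spark.sql.execution.id, so a bare
    // count() here would hit the IllegalArgumentException thrown by
    // withNewExecutionId; the wrapper tells it to skip that check for the nested run.
    SQLExecution.ignoreNestedExecutionId(sparkSession) {
      sparkSession.table(tableName).count()
    }
  }
}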
if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() + val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { + sparkSession.table(tableIdentWithDB).count() + } if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 184d0387ebfa..d36eb7587a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution case class CacheTableCommand( tableIdent: TableIdentifier, @@ -33,16 +34,16 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan) - .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) - .run(sparkSession) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + SQLExecution.ignoreNestedExecutionId(sparkSession) { + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).countInternal() + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() + } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index eadc6c94f4b3..99133bd70989 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption + val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + } inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index a06f1ce3287e..b11da7045de2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -129,14 +130,11 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - import scala.collection.JavaConverters._ - - val options = jdbcOptions.asProperties.asScala + - ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - - new JdbcRelationProvider().createRelation( - data.sparkSession.sqlContext, mode, options.toMap, data) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9e889ff67945..6fa7c113defa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { @@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) - .showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) + } } } @@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - data.showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { + data.show(numRowsToShow, isTruncated) + } ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a5dac469f85b..198a34258280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } } else { logDebug(s"Skipping already committed batch: $batchId") From cd6e3f0d53b35b9b7bcb6fff6d5cdfe801f03da9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 26 Jun 2017 14:03:13 +0800 Subject: [PATCH 2/2] address comments --- .../spark/sql/execution/SQLExecution.scala | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 699f3b06be3a..ca8bed5214f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -44,8 +44,11 @@ object SQLExecution { private val testing = sys.props.contains("spark.testing") private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + val sc = sparkSession.sparkContext + val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null + val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + if (testing && !isNestedExecution && !hasExecutionId) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. 
The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -67,7 +70,7 @@ object SQLExecution { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) executionIdToQueryExecution.put(executionId, queryExecution) - val r = try { + try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" @@ -86,10 +89,15 @@ object SQLExecution { executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } - r } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { - // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore this new execution id. - body + // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the + // `body`, so that Spark jobs issued in the `body` won't be tracked. + try { + sc.setLocalProperty(EXECUTION_ID_KEY, null) + body + } finally { + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + } } else { // Don't support nested `withNewExecutionId`. This is an example of the nested // `withNewExecutionId`: @@ -128,7 +136,8 @@ object SQLExecution { /** * Wrap an action which may have nested execution id. This method can be used to run an execution - * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. + * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, + * all Spark jobs issued in the body won't be tracked in UI. */ def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { val sc = sparkSession.sparkContext
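// Editorial sketch of the SparkContext local-property save/restore idiom that both
// withNewExecutionId and ignoreNestedExecutionId rely on above (illustrative only;
// `LocalPropertyUtils.withLocalProperty` is a hypothetical helper, not part of this
// patch). A property is cleared by setting it to null, and the previous value is
// restored in a finally block so a failing body cannot leak thread-local state into
// later actions on the same thread.
import org.apache.spark.SparkContext

object LocalPropertyUtils {
  def withLocalProperty[T](sc: SparkContext, key: String, value: String)(body: => T): T = {
    val previous = sc.getLocalProperty(key)
    try {
      sc.setLocalProperty(key, value)    // passing null removes the property
      body
    } finally {
      sc.setLocalProperty(key, previous) // restore whatever was set before
    }
  }
  // Example use, mirroring ignoreNestedExecutionId:
  //   withLocalProperty(sc, "spark.sql.execution.ignoreNestedExecutionId", "true") { body }
}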