From ba467036fdd2e6efe3ef2be66f378da341c73423 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 19 May 2018 18:51:02 +0800 Subject: [PATCH 1/4] support accessing SQLConf at executor side --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 50 ++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 9 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..643e4c686f58d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,13 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -107,7 +106,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1297,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1764,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..032525a08ccdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,37 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..d54bfbfc14f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..404d6313ab92c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadonlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From a1519d4aa692adceef1f3878a2ccd1715bf6175a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 20 May 2018 18:33:00 +0800 Subject: [PATCH 2/4] fix json schema inference --- .../org/apache/spark/sql/internal/SQLConf.scala | 1 + .../apache/spark/sql/execution/SQLExecution.scala | 4 ++++ .../datasources/json/JsonInferSchema.scala | 15 ++++++++++----- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 643e4c686f58d..6e4059cef4c7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 032525a08ccdb..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -105,6 +105,10 @@ object SQLExecution { } } + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { val sc = sparkSession.sparkContext // Set all the specified SQL configs to local properties, so that they can be available at diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e7eed95a560a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -45,8 +45,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,9 +67,13 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) canonicalizeType(rootType) match { case Some(st: StructType) => st From d25e8467aac66cbe6f7a2b2b038d2450a37e19c9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 20 May 2018 22:42:10 +0800 Subject: [PATCH 3/4] address comments --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- .../src/main/scala/org/apache/spark/util/EventLoop.scala | 3 ++- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 9 ++++++++- .../spark/sql/internal/ExecutorSideSQLConfSuite.scala | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..5f2d16d03165f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6e4059cef4c7d..ee77813b5b452 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -111,6 +111,13 @@ object SQLConf { if (TaskContext.get != null) { new ReadOnlySQLConf(TaskContext.get()) } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } confGetter.get()() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 404d6313ab92c..3dd0712e02448 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -40,7 +40,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { spark = null } - test("ReadonlySQLConf is correctly created at the executor side") { + test("ReadOnlySQLConf is correctly created at the executor side") { SQLConf.get.setConfString("spark.sql.x", "a") try { val checks = spark.range(10).mapPartitions { it => From f7a629906f1ba15b663eb8ab1c6b49daa48c34d2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 21 May 2018 15:41:53 +0800 Subject: [PATCH 4/4] address comments --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ee77813b5b452..a2fb3c64844b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -95,7 +95,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -112,6 +114,8 @@ object SQLConf { new ReadOnlySQLConf(TaskContext.get()) } else { if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. val schedulerEventLoopThread = SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread if (schedulerEventLoopThread.getId == Thread.currentThread().getId) {