diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index d2f27da23901..66ac9ddb21aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -145,9 +145,17 @@ object StaticSQLConf {
         "cause longer waiting for other broadcasting. Also, increasing parallelism may " +
         "cause memory problem.")
       .intConf
-      .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in [0,128].")
+      .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
       .createWithDefault(128)
 
+  val SUBQUERY_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.subquery.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism to execute the subquery.")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
+      .createWithDefault(16)
+
   val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
     .doc("Threshold of SQL length beyond which it will be truncated before adding to " +
       "event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 6046805ae95d..995d94ef5eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.concurrent.{ExecutionContext, Future}
+
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
@@ -164,4 +166,20 @@ object SQLExecution {
       }
     }
   }
+
+  /**
+   * Wrap passed function to ensure necessary thread-local variables like
+   * SparkContext local properties are forwarded to execution thread
+   */
+  def withThreadLocalCaptured[T](
+      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
+    val activeSession = sparkSession
+    val sc = sparkSession.sparkContext
+    val localProps = Utils.cloneProperties(sc.getLocalProperties)
+    Future {
+      SparkSession.setActiveSession(activeSession)
+      sc.setLocalProperties(localProps)
+      body
+    }(exec)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e128d59dca6b..f3f756425a15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.types.{LongType, StructType}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
@@ -749,7 +750,9 @@ case class SubqueryExec(name: String, child: SparkPlan)
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      sqlContext.sparkSession,
+      SubqueryExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
@@ -764,7 +767,7 @@ case class SubqueryExec(name: String, child: SparkPlan)
         SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
         rows
       }
-    }(SubqueryExec.executionContext)
+    }
   }
 
   protected override def doCanonicalize(): SparkPlan = {
@@ -788,7 +791,8 @@ case class SubqueryExec(name: String, child: SparkPlan)
 
 object SubqueryExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
+    ThreadUtils.newDaemonCachedThreadPool("subquery",
+      SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 776cdb107084..0cc658c49961 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.internal
 
 import org.scalatest.Assertions._
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -125,6 +125,38 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
     val e = intercept[SparkException](dummyQueryExecution1.toRdd.collect())
     assert(e.getCause.isInstanceOf[NoSuchElementException])
   }
+
+  test("SPARK-30556 propagate local properties to subquery execution thread") {
+    withSQLConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD.key -> "1") {
+      withTempView("l", "m", "n") {
+        Seq(true).toDF().createOrReplaceTempView("l")
+        val confKey = "spark.sql.y"
+
+        def createDataframe(confKey: String, confValue: String): Dataset[Boolean] = {
+          Seq(true)
+            .toDF()
+            .mapPartitions { _ =>
+              TaskContext.get.getLocalProperty(confKey) == confValue match {
+                case true => Iterator(true)
+                case false => Iterator.empty
+              }
+            }
+        }
+
+        // set local configuration and assert
+        val confValue1 = "e"
+        createDataframe(confKey, confValue1).createOrReplaceTempView("m")
+        spark.sparkContext.setLocalProperty(confKey, confValue1)
+        assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect.size == 1)
+
+        // change the conf value and assert again
+        val confValue2 = "f"
+        createDataframe(confKey, confValue2).createOrReplaceTempView("n")
+        spark.sparkContext.setLocalProperty(confKey, confValue2)
+        assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().size == 1)
+      }
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
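
Note on the pattern introduced by SQLExecution.withThreadLocalCaptured above: threads in a cached pool do not inherit the submitting thread's thread-locals, so the helper captures the active session and a clone of the SparkContext local properties eagerly and re-installs them inside the Future body before running the subquery. Below is a minimal standalone sketch of that capture-and-restore idea, not Spark's implementation; ThreadLocalCaptureSketch, localProp, withCaptured and pool are illustrative stand-ins.

import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object ThreadLocalCaptureSketch {
  // Stand-in for SparkContext local properties: a pool thread does not see the
  // submitting thread's value unless it is forwarded explicitly.
  private val localProp = new ThreadLocal[String]

  private val pool = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1))

  // Capture on the caller's thread, re-install inside the task, then run `body`.
  def withCaptured[T](body: => T): Future[T] = {
    val captured = localProp.get()   // evaluated eagerly, on the submitting thread
    Future {
      localProp.set(captured)        // restore before running the body on the pool thread
      body
    }(pool)
  }

  def main(args: Array[String]): Unit = {
    localProp.set("spark.sql.y=e")
    val seen = Await.result(withCaptured(localProp.get()), 10.seconds)
    assert(seen == "spark.sql.y=e", s"expected property to be forwarded, got $seen")
    pool.shutdown()
  }
}

The patch does the analogous thing with SparkSession.setActiveSession and sc.setLocalProperties; cloning the properties via Utils.cloneProperties matters because the submitting thread may keep mutating its own local properties after the subquery has been scheduled.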