From ede4044aed0b2e038fb7b3bb2d46b383d9892c48 Mon Sep 17 00:00:00 2001
From: Ajith
Date: Thu, 23 Jan 2020 15:50:18 +0530
Subject: [PATCH] Backport SPARK-30556 from master

---
 .../scala/org/apache/spark/util/Utils.scala    |  7 ++++
 .../spark/sql/internal/StaticSQLConf.scala     |  7 ++++
 .../spark/sql/execution/SQLExecution.scala     | 20 +++++++++-
 .../execution/basicPhysicalOperators.scala     | 10 +++--
 .../internal/ExecutorSideSQLConfSuite.scala    | 38 ++++++++++++++++++-
 5 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8f86b472b937..2e516146ec14 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2863,6 +2863,13 @@ private[spark] object Utils extends Logging {
   def stringHalfWidth(str: String): Int = {
     if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size
   }
+
+  /** Create a new properties object with the same values as `props` */
+  def cloneProperties(props: Properties): Properties = {
+    val resultProps = new Properties()
+    props.asScala.foreach(entry => resultProps.put(entry._1, entry._2))
+    resultProps
+  }
 }
 
 private[util] object CallerContext extends Logging {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index d9c354b165e5..4b5bb8578182 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -126,4 +126,11 @@ object StaticSQLConf {
       .intConf
       .createWithDefault(1000)
 
+  val SUBQUERY_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.subquery.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism to execute the subquery.")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
+      .createWithDefault(16)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 439932b0cc3a..296076d67eb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
-import org.apache.spark.SparkContext
+import scala.concurrent.{ExecutionContext, Future}
+
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
+import org.apache.spark.util.Utils
 
 object SQLExecution {
 
@@ -129,4 +131,20 @@
       }
     }
   }
+
+  /**
+   * Wrap passed function to ensure necessary thread-local variables like
+   * SparkContext local properties are forwarded to execution thread
+   */
+  def withThreadLocalCaptured[T](
+      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
+    val activeSession = sparkSession
+    val sc = sparkSession.sparkContext
+    val localProps = Utils.cloneProperties(sc.getLocalProperties)
+    Future {
+      SparkSession.setActiveSession(activeSession)
+      sc.setLocalProperties(localProps)
+      body
+    }(exec)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 27aa9b8e3b1e..29af8440b9e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.util.ThreadUtils
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
@@ -658,7 +659,9 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      sqlContext.sparkSession,
+      SubqueryExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
@@ -673,7 +676,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
         SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
         rows
       }
-    }(SubqueryExec.executionContext)
+    }
   }
 
   protected override def doPrepare(): Unit = {
@@ -691,5 +694,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
 
 object SubqueryExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
+    ThreadUtils.newDaemonCachedThreadPool("subquery",
+      SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index ae7206b8f46c..2233f4d2a1c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.sql.internal
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -129,6 +129,40 @@
       }
     }
   }
+
+  test("SPARK-30556 propagate local properties to subquery execution thread") {
+    withSQLConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD.key -> "1") {
+      withTempView("l", "m", "n") {
+        Seq(true).toDF().createOrReplaceTempView("l")
+        val confKey = "spark.sql.y"
+
+        def createDataframe(confKey: String, confValue: String): Dataset[Boolean] = {
+          Seq(true)
+            .toDF()
+            .mapPartitions { _ =>
+              TaskContext.get.getLocalProperty(confKey) == confValue match {
+                case true => Iterator(true)
+                case false => Iterator.empty
+              }
+            }
+        }
+
+        // set local configuration and assert
+        val confValue1 = "e"
+        createDataframe(confKey, confValue1).createOrReplaceTempView("m")
+        spark.sparkContext.setLocalProperty(confKey, confValue1)
+        val result1 = sql("SELECT value, (SELECT MAX(*) FROM m) x FROM l").collect
+        assert(result1.forall(_.getBoolean(1)))
+
+        // change the conf value and assert again
+        val confValue2 = "f"
+        createDataframe(confKey, confValue2).createOrReplaceTempView("n")
+        spark.sparkContext.setLocalProperty(confKey, confValue2)
+        val result2 = sql("SELECT value, (SELECT MAX(*) FROM n) x FROM l").collect
+        assert(result2.forall(_.getBoolean(1)))
+      }
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
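
Note (illustrative sketch, not part of the patch): what the backported SQLExecution.withThreadLocalCaptured helper guarantees, assuming a local SparkSession named `spark` and a hypothetical single-thread pool `pool` standing in for SubqueryExec.executionContext. A SparkContext local property set on the calling thread remains visible inside the body even though the body runs on the pool's thread, because the local properties are cloned and re-applied before the body executes.

    import java.util.concurrent.Executors

    import scala.concurrent.{Await, ExecutionContext}
    import scala.concurrent.duration._

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.SQLExecution

    val spark = SparkSession.builder().master("local[2]").appName("SPARK-30556-sketch").getOrCreate()
    // Hypothetical pool used only for this example.
    val pool = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1))

    // Local property set on the submitting (driver) thread.
    spark.sparkContext.setLocalProperty("spark.sql.y", "e")

    // The body is scheduled on `pool`, yet still observes the property,
    // because withThreadLocalCaptured forwards the cloned local properties.
    val f = SQLExecution.withThreadLocalCaptured(spark, pool) {
      spark.sparkContext.getLocalProperty("spark.sql.y")
    }
    assert(Await.result(f, 10.seconds) == "e")

Without the helper (i.e. with a plain scala.concurrent.Future as before the patch), the same lookup can return null once the subquery thread pool reuses a thread that never saw the property, which is the race SPARK-30556 fixes.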