7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2863,6 +2863,13 @@ private[spark] object Utils extends Logging {
def stringHalfWidth(str: String): Int = {
if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size
}

/** Create a new properties object with the same values as `props` */
def cloneProperties(props: Properties): Properties = {
val resultProps = new Properties()
props.asScala.foreach(entry => resultProps.put(entry._1, entry._2))
resultProps
}
}

private[util] object CallerContext extends Logging {
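For context, cloning matters because the SparkContext local properties are held in a mutable Properties object that the submitting thread can keep changing; the copy taken here is a point-in-time snapshot. A minimal standalone sketch of that snapshot behavior, mirroring the helper above (illustrative, not part of the patch):

import java.util.Properties
import scala.collection.JavaConverters._

// Same approach as Utils.cloneProperties: copy every entry into a fresh Properties object.
def cloneProperties(props: Properties): Properties = {
  val resultProps = new Properties()
  props.asScala.foreach { case (k, v) => resultProps.put(k, v) }
  resultProps
}

val original = new Properties()
original.setProperty("spark.sql.y", "e")
val snapshot = cloneProperties(original)
original.setProperty("spark.sql.y", "f")            // caller mutates after the clone
assert(snapshot.getProperty("spark.sql.y") == "e")  // the snapshot is unaffected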
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -126,4 +126,11 @@ object StaticSQLConf {
.intConf
.createWithDefault(1000)

val SUBQUERY_MAX_THREAD_THRESHOLD =
buildStaticConf("spark.sql.subquery.maxThreadThreshold")
.internal()
.doc("The maximum degree of parallelism to execute the subquery.")
.intConf
.checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
.createWithDefault(16)
}
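Since spark.sql.subquery.maxThreadThreshold is a static conf, it is fixed for the lifetime of the SparkSession and must be supplied before the session is created; a minimal sketch of overriding the default of 16 at startup (the master, app name, and chosen value are illustrative):

import org.apache.spark.sql.SparkSession

// Static SQL confs cannot be changed with spark.conf.set once the session exists,
// so the thread threshold is passed to the builder up front.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("subquery-thread-threshold-demo")
  .config("spark.sql.subquery.maxThreadThreshold", "32")  // must be in (0, 128]
  .getOrCreate()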
sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.execution
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext
import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
import org.apache.spark.util.Utils

object SQLExecution {

@@ -129,4 +131,20 @@ object SQLExecution {
}
}
}

/**
* Wraps the passed function so that necessary thread-local variables, such as the
* SparkContext local properties, are forwarded to the execution thread.
*/
def withThreadLocalCaptured[T](
sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
Future {
SparkSession.setActiveSession(activeSession)
sc.setLocalProperties(localProps)
body
}(exec)
}
}
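This helper is needed because a Future body runs on a pool thread that does not see the submitting thread's thread-locals, so the local properties have to be captured before submission and re-installed inside the task. A small self-contained sketch of the same pattern with a plain ThreadLocal (illustrative only, not Spark code):

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object ThreadLocalCaptureDemo extends App {
  val local = new ThreadLocal[String]
  val pool = Executors.newFixedThreadPool(1)
  implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

  local.set("caller-value")

  // Submitted as-is: the pool thread has its own, empty ThreadLocal slot.
  val unpropagated = Future { Option(local.get) }

  // Captured on the caller thread and restored inside the task: the same idea as
  // cloning sc.getLocalProperties above and calling sc.setLocalProperties in the Future.
  val snapshot = local.get
  val captured = Future { local.set(snapshot); Option(local.get) }

  println(Await.result(unpropagated, 10.seconds))  // None
  println(Await.result(captured, 10.seconds))      // Some(caller-value)
  pool.shutdown()
}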
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
@@ -658,7 +659,9 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Future {
SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
sqlContext.sparkSession,
SubqueryExec.executionContext) {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
@@ -673,7 +676,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
rows
}
}(SubqueryExec.executionContext)
}
}

protected override def doPrepare(): Unit = {
@@ -691,5 +694,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {

object SubqueryExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
ThreadUtils.newDaemonCachedThreadPool("subquery",
SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))
}
sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -17,9 +17,9 @@

package org.apache.spark.sql.internal

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -129,6 +129,40 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
}
}
}

test("SPARK-30556 propagate local properties to subquery execution thread") {
withSQLConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD.key -> "1") {
withTempView("l", "m", "n") {
Seq(true).toDF().createOrReplaceTempView("l")
val confKey = "spark.sql.y"

def createDataframe(confKey: String, confValue: String): Dataset[Boolean] = {
Seq(true)
.toDF()
.mapPartitions { _ =>
TaskContext.get.getLocalProperty(confKey) == confValue match {
case true => Iterator(true)
case false => Iterator.empty
}
}
}

// set local configuration and assert
val confValue1 = "e"
createDataframe(confKey, confValue1).createOrReplaceTempView("m")
spark.sparkContext.setLocalProperty(confKey, confValue1)
val result1 = sql("SELECT value, (SELECT MAX(*) FROM m) x FROM l").collect
assert(result1.forall(_.getBoolean(1)))

// change the conf value and assert again
val confValue2 = "f"
createDataframe(confKey, confValue2).createOrReplaceTempView("n")
spark.sparkContext.setLocalProperty(confKey, confValue2)
val result2 = sql("SELECT value, (SELECT MAX(*) FROM n) x FROM l").collect
assert(result2.forall(_.getBoolean(1)))
}
}
}
}

case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {