sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -263,7 +263,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* All the subqueries of current plan.
*/
def subqueries: Seq[PlanType] = {
expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
expressions.flatMap(_.collect {
case e: SubqueryExpression => e
Contributor: Should we use ExpressionCanonicalizer to canonicalize the expression before calling distinct?

Contributor (author): Because we use ReusedScalarSubquery, not Alias, to indicate the reused SubqueryExpression, I don't think we need to use ExpressionCanonicalizer.

}).distinct.map(_.plan.asInstanceOf[PlanType])
}

override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
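
As context for the distinct call above: ScalarSubquery in this patch overrides equals/hashCode to reference identity (see the subquery.scala changes below), so de-duplicating by instance is enough and no canonicalization pass is needed. A minimal standalone sketch of that behavior, using hypothetical names rather than the actual Spark classes:

// Standalone sketch (hypothetical class, not Spark's): equality by
// reference identity makes Seq.distinct de-duplicate by instance.
object DistinctByIdentitySketch {
  class Subquery(val id: Long) {
    override def equals(o: Any): Boolean = o match {
      case other: Subquery => this.eq(other)
      case _ => false
    }
    override def hashCode: Int = id.hashCode
  }

  def main(args: Array[String]): Unit = {
    val shared = new Subquery(1)
    // The same instance appears twice, as it would when several
    // ReusedScalarSubquery expressions point at one ScalarSubquery.
    val occurrences = Seq(shared, shared, new Subquery(2))
    assert(occurrences.distinct.size == 2)
  }
}
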
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext

import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -138,48 +137,29 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}

/**
* List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
* List of uncorrelated scalar subqueries for this plan node.
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
private var allSubqueries = new ArrayBuffer[ScalarSubquery]

/**
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries are added to [[subqueryResults]].
* The subqueries found are added to [[allSubqueries]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
subqueryResults += e -> futureResult
expressions.flatMap(_.collect { case e: ScalarSubquery => e }).distinct.foreach { e =>
e.submitSubqueryEvaluated()
allSubqueries += e
}
}

/**
* Blocks the thread until all subqueries finish evaluation and update the results.
* Blocks the thread until all subqueries finish evaluation.
*/
protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
}
subqueryResults.clear()
allSubqueries.foreach(_.awaitSubqueryResult())
allSubqueries.clear()
}

/**
@@ -390,11 +370,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

object SparkPlan {
private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}

private[sql] trait LeafExecNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
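
The split between prepareSubqueries (submit everything up front) and waitForSubqueries (block once per node) now delegates to ScalarSubquery, which memoizes a single Future per subquery instance; the implementation appears in subquery.scala below. A self-contained sketch of that submit-once/await-later pattern, with all names illustrative rather than taken from this diff:

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Hypothetical helper mirroring submitSubqueryEvaluated()/awaitSubqueryResult():
// the computation is submitted at most once, and every caller awaits the
// same shared Future.
class OnceEvaluated[T](compute: () => T)(implicit ec: ExecutionContext) {
  private var submitted = false
  private var future: Future[T] = _

  // Idempotent submission; a second call finds `submitted` set and returns.
  def submit(): Unit = synchronized {
    if (!submitted) {
      future = Future(compute())
      submitted = true
    }
  }

  // Blocks until the shared computation finishes. Assumes submit() ran
  // first, just as prepare() guarantees before waitForSubqueries().
  def await(): T = synchronized {
    Await.result(future, Duration.Inf)
  }
}

object OnceEvaluatedDemo {
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  def main(args: Array[String]): Unit = {
    var runs = 0
    val shared = new OnceEvaluated(() => { runs += 1; 42 })
    shared.submit()
    shared.submit() // no-op: the first Future is kept
    assert(shared.await() == 42)
    assert(runs == 1)
  }
}

Because both submits share one Future, a subquery referenced from several plan nodes runs only once; this is the same guarantee the patch gets from ReusedScalarSubquery pointing at a shared ScalarSubquery instance.
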
100 changes: 93 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,14 +17,20 @@

package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ThreadUtils

/**
* A subquery that will return only one row and one column.
@@ -48,10 +54,12 @@ case class ScalarSubquery(
override def toString: String = s"subquery#${exprId.id}"

// the first column in first row from `query`.
@volatile private var result: Any = null
@volatile private var updated: Boolean = false
@volatile private[this] var result: Any = null
@volatile private[this] var updated: Boolean = false
@transient private[this] var evaluated: Boolean = false
@transient private[this] var futureResult: Future[Array[InternalRow]] = _

def updateResult(v: Any): Unit = {
private def updateResult(v: Any): Unit = {
result = v
updated = true
}
@@ -65,17 +73,95 @@
require(updated, s"$this has not finished")
Literal.create(result, dataType).doGenCode(ctx, ev)
}

/**
* Submit the subquery to be evaluated. There is no need to do so if it has already been evaluated.
*/
def submitSubqueryEvaluated(): Unit = synchronized {
if (!evaluated) {
futureResult = Future {
// Each subquery should return only one row (and one column). We take two here and throw
// an exception later if the number of rows is greater than one.
executedPlan.executeTake(2)
}(ScalarSubquery.subqueryExecutionContext)
evaluated = true
}
}

/**
* Blocks the thread until the evaluation of the subquery has finished.
*/
def awaitSubqueryResult(): Unit = synchronized {
if (!updated) {
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
updateResult(rows(0).get(0, dataType))
} else {
// If no rows are returned, the result should be null.
updateResult(null)
}
}
}

override def equals(o: Any): Boolean = o match {
case other: ScalarSubquery => this.eq(other)
case _ => false
}

override def hashCode: Int = exprId.hashCode()
}

object ScalarSubquery {
private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}

/**
* A wrapper for a reused uncorrelated ScalarSubquery, used to avoid re-computing subqueries with
* the same "canonical" logical plan in a query: uncorrelated subqueries with the same
* "canonical" logical plan always produce the same result.
*/
case class ReusedScalarSubquery(
exprId: ExprId,
child: ScalarSubquery) extends UnaryExpression {

override def dataType: DataType = child.dataType
override def toString: String = s"ReusedSubquery#${exprId.id}($child)"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)

protected override def nullSafeEval(input: Any): Any = input
}

/**
* Plans scalar subqueries that are present in the given [[SparkPlan]].
*/
case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
// Build a hash map using schema of subquery's logical plan to avoid O(N*N) sameResult calls.
val subqueryMap = mutable.HashMap[StructType, ArrayBuffer[(LogicalPlan, ScalarSubquery)]]()
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
val sameSchema = subqueryMap.getOrElseUpdate(
subquery.query.schema, ArrayBuffer[(LogicalPlan, ScalarSubquery)]())
val samePlan = sameSchema.find { case (e, _) =>
subquery.query.sameResult(e)
}
if (samePlan.isDefined) {
// Subqueries that have the same logical plan can reuse the same result.
ReusedScalarSubquery(subquery.exprId, samePlan.get._2)
} else {
val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
val physicalSubquery = new ScalarSubquery(executedPlan, subquery.exprId)
sameSchema += ((subquery.plan, physicalSubquery))
physicalSubquery
}
}
}
}
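
PlanSubqueries above buckets candidate plans by schema so the expensive sameResult check only runs within a bucket, instead of requiring O(N^2) pairwise comparisons. A generic sketch of that bucketed de-duplication technique (hypothetical helper, not a Spark API):

import scala.collection.mutable

// Generic "keep one representative per equivalence class" helper: a cheap
// hashable key (the schema, in the patch) narrows the expensive equivalence
// check (sameResult, in the patch) to a single bucket.
object BucketedDedupSketch {
  def dedup[A, K](items: Seq[A])(key: A => K)(equiv: (A, A) => Boolean): Seq[A] = {
    val buckets = mutable.HashMap[K, mutable.ArrayBuffer[A]]()
    val result = mutable.ArrayBuffer[A]()
    for (item <- items) {
      val bucket = buckets.getOrElseUpdate(key(item), mutable.ArrayBuffer[A]())
      bucket.find(equiv(_, item)) match {
        case Some(_) => // an equivalent item was already kept; a caller would reuse it
        case None =>
          bucket += item
          result += item
      }
    }
    result.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Strings stand in for plans; length plays the role of the schema key
    // and case-insensitive equality the role of sameResult.
    val deduped = dedup(Seq("max(a)", "MAX(A)", "min(b)"))(_.length)(_.equalsIgnoreCase(_))
    assert(deduped == Seq("max(a)", "min(b)"))
  }
}
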
27 changes: 27 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -571,4 +571,31 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}

test("SPARK-16456: Reuse the uncorrelated scalar subqueries with the same logical plan") {
withTempTable("t1", "t2", "t3") {
val df = (1 to 3).map(i => (i, i)).toDF("key", "value")
df.createOrReplaceTempView("t1")
df.createOrReplaceTempView("t2")
df.createOrReplaceTempView("t3")
checkAnswer(
sql(
"""
|WITH max_test AS
|(
| SELECT max(key) as max_key FROM t1
|),
|max_test2 AS
|(
| SELECT max(key) as max_key FROM t1
|)
|SELECT key FROM t2
|WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test)
|UNION ALL
|SELECT key FROM t3
|WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test2)
""".stripMargin
), Row(3) :: Row(3) :: Nil)
}
}
}
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
@@ -518,6 +518,40 @@ class PlannerSuite extends SharedSQLContext {
fail(s"Should have only two shuffles:\n$outputPlan")
}
}

test("SPARK-16456: Reuse the uncorrelated scalar subqueries with the same logical plan") {
withTempTable("t1", "t2", "t3") {
val df = (1 to 3).map(i => (i, i)).toDF("key", "value")
df.createOrReplaceTempView("t1")
df.createOrReplaceTempView("t2")
df.createOrReplaceTempView("t3")
val planned = sql(
"""
|WITH max_test AS
|(
| SELECT max(key) as max_key FROM t1
|),
|max_test2 AS
|(
| SELECT max(key) as max_key FROM t1
|)
|SELECT key FROM t2
|WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test)
|UNION ALL
|SELECT key FROM t3
|WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test2)
""".stripMargin
).queryExecution.executedPlan
val numExecutedSubqueries = planned.flatMap {
case plan => plan.expressions.flatMap(_.collect { case e: ScalarSubquery => e })
}.distinct.size
assert(numExecutedSubqueries === 1)
val numReusedSubqueries = planned.flatMap {
case plan => plan.expressions.flatMap(_.collect { case e: ReusedScalarSubquery => e })
}.size
assert(numReusedSubqueries === 3)
}
}
}

// Used for unit-testing EnsureRequirements