Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -126,8 +127,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
* merged as there can be subqueries that are different ([[checkIdenticalPlans]] is
* false) due to an extra [[Project]] node in one of them. In that case
* `attributes.size` remains 1 after merging, but the merged flag becomes true.
* @param references The indexes in the cache of all subqueries nested in this subquery's plan,
* including transitively nested ones.
*/
case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: Boolean)
case class Header(
attributes: Seq[Attribute],
plan: LogicalPlan,
merged: Boolean,
references: Set[Int])

private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
val cache = ArrayBuffer.empty[Header]
Expand Down Expand Up @@ -166,26 +173,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
// "Header".
private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = {
// Finds a cache entry that `plan` is identical to or can be merged into, or appends a new
// entry. The result pairs the index of that entry in `cache` with the position, within the
// entry's `attributes`, of the plan's first output attribute.
val output = plan.output.head
// NOTE(review): this span appears to come from a rendered diff in which removed and added
// lines are interleaved without +/- markers. The lookup that follows seems to be the
// pre-change version, which does not consult `references` -- confirm against the repository.
cache.zipWithIndex.collectFirst(Function.unlift { case (header, subqueryIndex) =>
checkIdenticalPlans(plan, header.plan).map { outputMap =>
val mappedOutput = mapAttributes(output, outputMap)
val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
subqueryIndex -> headerIndex
}.orElse(tryMergePlans(plan, header.plan).map {
case (mergedPlan, outputMap) =>
// Collect the cache indexes of every subquery that `plan` references: each
// ScalarSubqueryReference found in the plan contributes its own index plus the references
// already recorded for that cache entry, which makes the set transitive. Every matched node
// is returned unchanged and the transformed plan is discarded; the traversal is used only
// to visit the expressions.
val references = mutable.HashSet.empty[Int]
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
case ssr: ScalarSubqueryReference =>
references += ssr.subqueryIndex
references ++= cache(ssr.subqueryIndex).references
ssr
}

cache.zipWithIndex.collectFirst(Function.unlift {
// Only cache entries that `plan` does not reference (directly or transitively) are
// candidates, so the plan is never matched against or merged into a subquery nested in it.
case (header, subqueryIndex) if !references.contains(subqueryIndex) =>
checkIdenticalPlans(plan, header.plan).map { outputMap =>
val mappedOutput = mapAttributes(output, outputMap)
var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
val newHeaderAttributes = if (headerIndex == -1) {
headerIndex = header.attributes.size
header.attributes :+ mappedOutput
} else {
header.attributes
}
cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, true)
val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
subqueryIndex -> headerIndex
})
}.orElse{
tryMergePlans(plan, header.plan).map {
case (mergedPlan, outputMap) =>
val mappedOutput = mapAttributes(output, outputMap)
// An index of -1 means the mapped output is not among the header's attributes yet, so it
// is appended and its new position (the previous size) is used as the result index.
var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
val newHeaderAttributes = if (headerIndex == -1) {
headerIndex = header.attributes.size
header.attributes :+ mappedOutput
} else {
header.attributes
}
// Replace the cache entry with a header for the merged plan, flagged as merged, whose
// references are the union of the entry's previous references and this plan's references.
cache(subqueryIndex) =
Header(newHeaderAttributes, mergedPlan, true, header.references ++ references)
subqueryIndex -> headerIndex
}
}
// Entries excluded by the guard above yield no match.
case _ => None
}).getOrElse {
// No cache entry matched or could be merged: append a new, unmerged header for `plan` and
// return its index together with attribute position 0.
cache += Header(Seq(output), plan, false)
cache += Header(Seq(output), plan, false, references.toSet)
cache.length - 1 -> 0
}
}
Expand All @@ -210,12 +230,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse(
(newPlan, cachedPlan) match {
case (_, _) if newPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) ||
cachedPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) =>
// Subquery expressions with nested subquery expressions within are not supported for now.
// TODO: support this optimization by collecting the transitive subquery references in the
// new plan and recording them in order to suppress merging the new plan into those.
None
case (np: Project, cp: Project) =>
tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) =>
val (mergedProjectList, newOutputMap) =
Expand Down
35 changes: 29 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,7 @@ class SubquerySuite extends QueryTest
}
}

test("SPARK-40618: Do not merge scalar subqueries with nested subqueries inside") {
test("Merge non-correlated scalar subqueries from different parent plans") {
Seq(false, true).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
Expand Down Expand Up @@ -2189,13 +2189,13 @@ class SubquerySuite extends QueryTest
}

if (enableAQE) {
assert(subqueryIds.size == 4, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 2,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
} else {
assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 3,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
} else {
assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 4,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
}
}
}
Expand Down Expand Up @@ -2332,9 +2332,32 @@ class SubquerySuite extends QueryTest
// This test contains a subquery expression with another subquery expression nested inside.
// It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt
// to merge them together.
withTable("t") {
withTable("t", "t2") {
sql("create table t(col int) using csv")
checkAnswer(sql("select(select sum((select sum(col) from t)) from t)"), Row(null))

checkAnswer(sql(
"""
|select
| (select sum(
| (select sum(
| (select sum(col) from t))
| from t))
| from t)
|""".stripMargin),
Row(null))

sql("create table t2(col int) using csv")
checkAnswer(sql(
"""
|select
| (select sum(
| (select sum(
| (select sum(col) from t))
| from t2))
| from t)
|""".stripMargin),
Row(null))
}
}
}