diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 69f77e8f3f460..1cb3f3f157cfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -126,8 +127,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { * merged as there can be subqueries that are different ([[checkIdenticalPlans]] is * false) due to an extra [[Project]] node in one of them. In that case * `attributes.size` remains 1 after merging, but the merged flag becomes true. + * @param references A set of subquery indexes in the cache to track all (including transitive) + * nested subqueries. */ - case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: Boolean) + case class Header( + attributes: Seq[Attribute], + plan: LogicalPlan, + merged: Boolean, + references: Set[Int]) private def extractCommonScalarSubqueries(plan: LogicalPlan) = { val cache = ArrayBuffer.empty[Header] @@ -166,26 +173,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { // "Header". private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = { val output = plan.output.head - cache.zipWithIndex.collectFirst(Function.unlift { case (header, subqueryIndex) => - checkIdenticalPlans(plan, header.plan).map { outputMap => - val mappedOutput = mapAttributes(output, outputMap) - val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) - subqueryIndex -> headerIndex - }.orElse(tryMergePlans(plan, header.plan).map { - case (mergedPlan, outputMap) => + val references = mutable.HashSet.empty[Int] + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + references += ssr.subqueryIndex + references ++= cache(ssr.subqueryIndex).references + ssr + } + + cache.zipWithIndex.collectFirst(Function.unlift { + case (header, subqueryIndex) if !references.contains(subqueryIndex) => + checkIdenticalPlans(plan, header.plan).map { outputMap => val mappedOutput = mapAttributes(output, outputMap) - var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) - val newHeaderAttributes = if (headerIndex == -1) { - headerIndex = header.attributes.size - header.attributes :+ mappedOutput - } else { - header.attributes - } - cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, true) + val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) subqueryIndex -> headerIndex - }) + }.orElse{ + tryMergePlans(plan, header.plan).map { + case (mergedPlan, outputMap) => + val mappedOutput = mapAttributes(output, outputMap) + var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) + val newHeaderAttributes = if (headerIndex == -1) { + headerIndex = header.attributes.size + header.attributes :+ mappedOutput + } else { + header.attributes + } + cache(subqueryIndex) = + Header(newHeaderAttributes, mergedPlan, true, header.references ++ references) + subqueryIndex -> headerIndex + } + } + case _ => None }).getOrElse { - cache += Header(Seq(output), plan, false) + cache += Header(Seq(output), plan, false, references.toSet) cache.length - 1 -> 0 } } @@ -210,12 +230,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = { checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse( (newPlan, cachedPlan) match { - case (_, _) if newPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) || - cachedPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) => - // Subquery expressions with nested subquery expressions within are not supported for now. - // TODO: support this optimization by collecting the transitive subquery references in the - // new plan and recording them in order to suppress merging the new plan into those. - None case (np: Project, cp: Project) => tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) => val (mergedProjectList, newOutputMap) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 765c2a5223759..a86ddf388a9b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2157,7 +2157,7 @@ class SubquerySuite extends QueryTest } } - test("SPARK-40618: Do not merge scalar subqueries with nested subqueries inside") { + test("Merge non-correlated scalar subqueries from different parent plans") { Seq(false, true).foreach { enableAQE => withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { @@ -2189,13 +2189,13 @@ class SubquerySuite extends QueryTest } if (enableAQE) { - assert(subqueryIds.size == 4, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } else { assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan") assert(reusedSubqueryIds.size == 3, "Missing or unexpected reused ReusedSubqueryExec in the plan") + } else { + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 4, + "Missing or unexpected reused ReusedSubqueryExec in the plan") } } } @@ -2332,9 +2332,32 @@ class SubquerySuite extends QueryTest // This test contains a subquery expression with another subquery expression nested inside. // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt // to merge them together. - withTable("t") { + withTable("t", "t2") { sql("create table t(col int) using csv") checkAnswer(sql("select(select sum((select sum(col) from t)) from t)"), Row(null)) + + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t)) + | from t)) + | from t) + |""".stripMargin), + Row(null)) + + sql("create table t2(col int) using csv") + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t)) + | from t2)) + | from t) + |""".stripMargin), + Row(null)) } } }