diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 976a5d385d87..62ebfa834318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Command, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE} @@ -55,27 +55,27 @@ object CTESubstitution extends Rule[LogicalPlan] { case _: Command | _: ParsedStatement | _: InsertIntoDir => true case _ => false } - val cteDefs = mutable.ArrayBuffer.empty[CTERelationDef] + val cteDefs = ArrayBuffer.empty[CTERelationDef] val (substituted, lastSubstituted) = LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match { case LegacyBehaviorPolicy.EXCEPTION => assertNoNameConflictsInCTE(plan) - traverseAndSubstituteCTE(plan, isCommand, cteDefs) + traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs) case LegacyBehaviorPolicy.LEGACY => (legacyTraverseAndSubstituteCTE(plan, cteDefs), None) case LegacyBehaviorPolicy.CORRECTED => - traverseAndSubstituteCTE(plan, isCommand, cteDefs) + traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs) } if (cteDefs.isEmpty) { substituted } else if (substituted eq lastSubstituted.get) { - WithCTE(substituted, cteDefs.sortBy(_.id).toSeq) + WithCTE(substituted, cteDefs.toSeq) } else { var done = false substituted.resolveOperatorsWithPruning(_ => !done) { case p if p eq lastSubstituted.get => done = true - WithCTE(p, cteDefs.sortBy(_.id).toSeq) + WithCTE(p, cteDefs.toSeq) } } } @@ -98,7 +98,7 @@ object CTESubstitution extends Rule[LogicalPlan] { val resolver = conf.resolver plan match { case UnresolvedWith(child, relations) => - val newNames = mutable.ArrayBuffer.empty[String] + val newNames = ArrayBuffer.empty[String] newNames ++= outerCTERelationNames relations.foreach { case (name, relation) => @@ -121,11 +121,11 @@ object CTESubstitution extends Rule[LogicalPlan] { private def legacyTraverseAndSubstituteCTE( plan: LogicalPlan, - cteDefs: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { + cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = { plan.resolveOperatorsUp { case UnresolvedWith(child, relations) => val resolvedCTERelations = - resolveCTERelations(relations, isLegacy = true, isCommand = false, cteDefs) + resolveCTERelations(relations, isLegacy = true, isCommand = false, Seq.empty, cteDefs) substituteCTE(child, alwaysInline = true, resolvedCTERelations) } } @@ -170,21 +170,23 @@ object CTESubstitution extends Rule[LogicalPlan] { * SELECT * FROM t * ) * @param plan the plan to be traversed - * @return the plan where CTE substitution is applied + * @param isCommand if this is a command + * @param outerCTEDefs already resolved outer CTE definitions with names + * @param cteDefs all accumulated CTE definitions + * @return the plan where CTE substitution is applied and optionally the last substituted `With` + * where CTE definitions will be gathered to */ private def traverseAndSubstituteCTE( plan: LogicalPlan, isCommand: Boolean, - cteDefs: mutable.ArrayBuffer[CTERelationDef]): (LogicalPlan, Option[LogicalPlan]) = { + outerCTEDefs: Seq[(String, CTERelationDef)], + cteDefs: ArrayBuffer[CTERelationDef]): (LogicalPlan, Option[LogicalPlan]) = { var lastSubstituted: Option[LogicalPlan] = None val newPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) { case UnresolvedWith(child: LogicalPlan, relations) => val resolvedCTERelations = - resolveCTERelations(relations, isLegacy = false, isCommand, cteDefs) - if (!isCommand) { - cteDefs ++= resolvedCTERelations.map(_._2) - } + resolveCTERelations(relations, isLegacy = false, isCommand, outerCTEDefs, cteDefs) lastSubstituted = Some(substituteCTE(child, isCommand, resolvedCTERelations)) lastSubstituted.get @@ -200,10 +202,14 @@ object CTESubstitution extends Rule[LogicalPlan] { relations: Seq[(String, SubqueryAlias)], isLegacy: Boolean, isCommand: Boolean, - cteDefs: mutable.ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = { - val resolvedCTERelations = new mutable.ArrayBuffer[(String, CTERelationDef)](relations.size) + outerCTEDefs: Seq[(String, CTERelationDef)], + cteDefs: ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = { + var resolvedCTERelations = if (isLegacy || isCommand) { + Seq.empty + } else { + outerCTEDefs + } for ((name, relation) <- relations) { - val lastCTEDefCount = cteDefs.length val innerCTEResolved = if (isLegacy) { // In legacy mode, outer CTE relations take precedence. Here we don't resolve the inner // `With` nodes, later we will substitute `UnresolvedRelation`s with outer CTE relations. @@ -221,31 +227,18 @@ object CTESubstitution extends Rule[LogicalPlan] { // WITH t3 AS (SELECT * FROM t1) // ) // t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`. - traverseAndSubstituteCTE(relation, isCommand, cteDefs)._1 - } - - if (cteDefs.length > lastCTEDefCount) { - // We have added more CTE relations to the `cteDefs` from the inner CTE, and these relations - // should also be substituted with `resolvedCTERelations` as inner CTE relation can refer to - // outer CTE relation. For example: - // WITH t1 AS (SELECT 1) - // t2 AS ( - // WITH t3 AS (SELECT * FROM t1) - // ) - for (i <- lastCTEDefCount until cteDefs.length) { - val substituted = - substituteCTE(cteDefs(i).child, isLegacy || isCommand, resolvedCTERelations.toSeq) - cteDefs(i) = cteDefs(i).copy(child = substituted) - } + traverseAndSubstituteCTE(relation, isCommand, resolvedCTERelations, cteDefs)._1 } - // CTE definition can reference a previous one - val substituted = - substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations.toSeq) + val substituted = substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations) val cteRelation = CTERelationDef(substituted) - resolvedCTERelations += (name -> cteRelation) + if (!(isLegacy || isCommand)) { + cteDefs += cteRelation + } + // Prepending new CTEs makes sure that those have higher priority over outer ones. + resolvedCTERelations +:= (name -> cteRelation) } - resolvedCTERelations.toSeq + resolvedCTERelations } private def substituteCTE( diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql index 3b64b5daa82d..b5d7fa5687bc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql @@ -135,4 +135,15 @@ WITH abc AS (SELECT 1) SELECT ( WITH aBc AS (SELECT 2) SELECT * FROM aBC -); \ No newline at end of file +); + +-- SPARK-38404: CTE in CTE definition references outer +WITH + t1 AS (SELECT 1), + t2 AS ( + WITH t3 AS ( + SELECT * FROM t1 + ) + SELECT * FROM t3 + ) +SELECT * FROM t2; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out index 4d0e5ea829d3..db7d420a745c 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query @@ -219,3 +219,20 @@ SELECT ( struct -- !query output 1 + + +-- !query +WITH + t1 AS (SELECT 1), + t2 AS ( + WITH t3 AS ( + SELECT * FROM t1 + ) + SELECT * FROM t3 + ) +SELECT * FROM t2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: t1; line 5 pos 20 diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out index a8db4599dafc..f714a11d1df3 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query @@ -227,3 +227,19 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException Name aBc is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedencePolicy to CORRECTED so that name defined in inner CTE takes precedence. If set it to LEGACY, outer CTE definitions will take precedence. See more details in SPARK-28228. + + +-- !query +WITH + t1 AS (SELECT 1), + t2 AS ( + WITH t3 AS ( + SELECT * FROM t1 + ) + SELECT * FROM t3 + ) +SELECT * FROM t2 +-- !query schema +struct<1:int> +-- !query output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out index 74394ee3ffc8..2ab13003d04d 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query @@ -219,3 +219,19 @@ SELECT ( struct -- !query output 2 + + +-- !query +WITH + t1 AS (SELECT 1), + t2 AS ( + WITH t3 AS ( + SELECT * FROM t1 + ) + SELECT * FROM t3 + ) +SELECT * FROM t2 +-- !query schema +struct<1:int> +-- !query output +1