diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 49f0d438d0a23..fd5a86b3ba403 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -145,9 +145,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB def checkAnalysis(plan: LogicalPlan): Unit = { val inlineCTE = InlineCTE(alwaysInline = true) - val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)] + val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] inlineCTE.buildCTEMap(plan, cteMap) - cteMap.values.foreach { case (relation, refCount) => + cteMap.values.foreach { case (relation, refCount, _) => // If a CTE relation is never used, it will disappear after inline. Here we explicitly check // analysis for it, to make sure the entire query plan is valid. 
if (refCount == 0) checkAnalysis0(relation.child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 1e4364b3f4a9d..8d7ff4cbf163d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -42,8 +42,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { - val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)] + val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] buildCTEMap(plan, cteMap) + cleanCTEMap(cteMap) val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] val inlined = inlineCTE(plan, cteMap, notInlined) // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add @@ -68,50 +69,91 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference])) } + /** + * Accumulates all the CTEs from a plan into a special map. + * + * @param plan The plan to collect the CTEs from + * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE + * ids. The value of the map is a tuple whose elements are: + * - The CTE definition + * - The number of incoming references to the CTE. This includes references from + * other CTEs and regular places. + * - A mutable inner map that tracks outgoing references (counts) to other CTEs. + * @param outerCTEId While collecting the map we use this optional CTE id to identify the + * current outer CTE. 
+ */ def buildCTEMap( plan: LogicalPlan, - cteMap: mutable.HashMap[Long, (CTERelationDef, Int)]): Unit = { + cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + outerCTEId: Option[Long] = None): Unit = { plan match { - case WithCTE(_, cteDefs) => + case WithCTE(child, cteDefs) => + cteDefs.foreach { cteDef => + cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0)) + } cteDefs.foreach { cteDef => - cteMap.put(cteDef.id, (cteDef, 0)) + buildCTEMap(cteDef, cteMap, Some(cteDef.id)) } + buildCTEMap(child, cteMap, outerCTEId) case ref: CTERelationRef => - val (cteDef, refCount) = cteMap(ref.cteId) - cteMap.update(ref.cteId, (cteDef, refCount + 1)) + val (cteDef, refCount, refMap) = cteMap(ref.cteId) + cteMap(ref.cteId) = (cteDef, refCount + 1, refMap) + outerCTEId.foreach { cteId => + val (_, _, outerRefMap) = cteMap(cteId) + outerRefMap(ref.cteId) += 1 + } case _ => - } - - if (plan.containsPattern(CTE)) { - plan.children.foreach { child => - buildCTEMap(child, cteMap) - } + if (plan.containsPattern(CTE)) { + plan.children.foreach { child => + buildCTEMap(child, cteMap, outerCTEId) + } - plan.expressions.foreach { expr => - if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) { - expr.foreach { - case e: SubqueryExpression => - buildCTEMap(e.plan, cteMap) - case _ => + plan.expressions.foreach { expr => + if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) { + expr.foreach { + case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, outerCTEId) + case _ => + } + } } } + } + } + + /** + * Cleans the CTE map: for each CTE that is not referenced at all, decreases the reference + * counts of the CTEs it refers to by the number of those (now useless) references. + * + * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE + * ids. Needs to be sorted to speed up cleaning. 
+ */ + private def cleanCTEMap( + cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] + ) = { + cteMap.keys.toSeq.reverse.foreach { currentCTEId => + val (_, currentRefCount, refMap) = cteMap(currentCTEId) + if (currentRefCount == 0) { + refMap.foreach { case (referencedCTEId, uselessRefCount) => + val (cteDef, refCount, refMap) = cteMap(referencedCTEId) + cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap) + } } } } private def inlineCTE( plan: LogicalPlan, - cteMap: mutable.HashMap[Long, (CTERelationDef, Int)], + cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => - val (cte, refCount) = cteMap(cteDef.id) + val (cte, refCount, refMap) = cteMap(cteDef.id) if (refCount > 0) { val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined)) - cteMap.update(cteDef.id, (inlined, refCount)) + cteMap(cteDef.id) = (inlined, refCount, refMap) if (!shouldInline(inlined, refCount)) { notInlined.append(inlined) } @@ -120,7 +162,7 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { inlineCTE(child, cteMap, notInlined) case ref: CTERelationRef => - val (cteDef, refCount) = cteMap(ref.cteId) + val (cteDef, refCount, _) = cteMap(ref.cteId) if (shouldInline(cteDef, refCount)) { if (ref.outputSet == cteDef.outputSet) { cteDef.child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 123364f18ceb4..e14a01e15a36c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4648,6 +4648,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("SELECT /*+ hash(t2) */ * FROM t1 join t2 on c1 = c2") } } + + 
test("SPARK-43199: InlineCTE is idempotent") { + sql( + """ + |WITH + | x(r) AS (SELECT random()), + | y(r) AS (SELECT * FROM x), + | z(r) AS (SELECT * FROM x) + |SELECT * FROM z + |""".stripMargin).collect() + } } case class Foo(bar: Option[String])