Skip to content

Commit 8a0927c

Browse files
committed
[SPARK-48307][SQL] InlineCTE should keep not-inlined relations in the original WithCTE node
### What changes were proposed in this pull request?

I noticed an outdated comment in the rule `InlineCTE`:

```
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
// WithCTE as top node here.
```

This is not true anymore after #42036. It's not a big deal, as we replace not-inlined CTE relations with `Repartition` during optimization, so it doesn't matter where we put the `WithCTE` node with not-inlined CTE relations, as it will disappear eventually. But it's still better to keep it in its original place, as third-party rules may be sensitive to the plan shape.

### Why are the changes needed?

To keep the plan shape as much as possible after inlining CTE relations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46617 from cloud-fan/cte.

Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent c7caac9 commit 8a0927c

File tree

3 files changed

+132
-88
lines changed

3 files changed

+132
-88
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -143,50 +143,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
143143
errorClass, missingCol, orderedCandidates, a.origin)
144144
}
145145

146-
private def checkUnreferencedCTERelations(
147-
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
148-
visited: mutable.Map[Long, Boolean],
149-
danglingCTERelations: mutable.ArrayBuffer[CTERelationDef],
150-
cteId: Long): Unit = {
151-
if (visited(cteId)) {
152-
return
153-
}
154-
val (cteDef, _, refMap) = cteMap(cteId)
155-
refMap.foreach { case (id, _) =>
156-
checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id)
157-
}
158-
danglingCTERelations.append(cteDef)
159-
visited(cteId) = true
160-
}
161-
162146
def checkAnalysis(plan: LogicalPlan): Unit = {
163-
val inlineCTE = InlineCTE(alwaysInline = true)
164-
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
165-
inlineCTE.buildCTEMap(plan, cteMap)
166-
val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef]
167-
val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false)
168-
// If a CTE relation is never used, it will disappear after inline. Here we explicitly collect
169-
// these dangling CTE relations, and put them back in the main query, to make sure the entire
170-
// query plan is valid.
171-
cteMap.foreach { case (cteId, (_, refCount, _)) =>
172-
// If a CTE relation ref count is 0, the other CTE relations that reference it should also be
173-
// collected. This code will also guarantee the leaf relations that do not reference
174-
// any others are collected first.
175-
if (refCount == 0) {
176-
checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId)
177-
}
178-
}
179-
// Inline all CTEs in the plan to help check query plan structures in subqueries.
180-
var inlinedPlan: LogicalPlan = plan
181-
try {
182-
inlinedPlan = inlineCTE(plan)
147+
// We should inline all CTE relations to restore the original plan shape, as the analysis check
148+
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
149+
// in the original `WithCTE` node, as we need to perform analysis check for them as well.
150+
val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true)
151+
val inlinedPlan: LogicalPlan = try {
152+
inlineCTE(plan)
183153
} catch {
184154
case e: AnalysisException =>
185155
throw new ExtendedAnalysisException(e, plan)
186156
}
187-
if (danglingCTERelations.nonEmpty) {
188-
inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq)
189-
}
190157
try {
191158
checkAnalysis0(inlinedPlan)
192159
} catch {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala

Lines changed: 84 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
3737
* query level.
3838
*
3939
* @param alwaysInline if true, inline all CTEs in the query plan.
40+
* @param keepDanglingRelations if true, dangling CTE relations will be kept in the original
41+
* `WithCTE` node.
4042
*/
41-
case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
43+
case class InlineCTE(
44+
alwaysInline: Boolean = false,
45+
keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] {
4246

4347
override def apply(plan: LogicalPlan): LogicalPlan = {
4448
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
45-
val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
49+
val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo]
4650
buildCTEMap(plan, cteMap)
4751
cleanCTEMap(cteMap)
48-
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
49-
val inlined = inlineCTE(plan, cteMap, notInlined)
50-
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
51-
// WithCTE as top node here.
52-
if (notInlined.isEmpty) {
53-
inlined
54-
} else {
55-
WithCTE(inlined, notInlined.toSeq)
56-
}
52+
inlineCTE(plan, cteMap)
5753
} else {
5854
plan
5955
}
@@ -74,34 +70,33 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
7470
*
7571
* @param plan The plan to collect the CTEs from
7672
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
77-
* ids. The value of the map is tuple whose elements are:
78-
* - The CTE definition
79-
* - The number of incoming references to the CTE. This includes references from
80-
* other CTEs and regular places.
81-
* - A mutable inner map that tracks outgoing references (counts) to other CTEs.
73+
* ids.
8274
* @param outerCTEId While collecting the map we use this optional CTE id to identify the
8375
* current outer CTE.
8476
*/
85-
def buildCTEMap(
77+
private def buildCTEMap(
8678
plan: LogicalPlan,
87-
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
79+
cteMap: mutable.Map[Long, CTEReferenceInfo],
8880
outerCTEId: Option[Long] = None): Unit = {
8981
plan match {
9082
case WithCTE(child, cteDefs) =>
9183
cteDefs.foreach { cteDef =>
92-
cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0))
84+
cteMap(cteDef.id) = CTEReferenceInfo(
85+
cteDef = cteDef,
86+
refCount = 0,
87+
outgoingRefs = mutable.Map.empty.withDefaultValue(0),
88+
shouldInline = true
89+
)
9390
}
9491
cteDefs.foreach { cteDef =>
9592
buildCTEMap(cteDef, cteMap, Some(cteDef.id))
9693
}
9794
buildCTEMap(child, cteMap, outerCTEId)
9895

9996
case ref: CTERelationRef =>
100-
val (cteDef, refCount, refMap) = cteMap(ref.cteId)
101-
cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
97+
cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1)
10298
outerCTEId.foreach { cteId =>
103-
val (_, _, outerRefMap) = cteMap(cteId)
104-
outerRefMap(ref.cteId) += 1
99+
cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1)
105100
}
106101

107102
case _ =>
@@ -129,46 +124,58 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
129124
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
130125
* ids. Needs to be sorted to speed up cleaning.
131126
*/
132-
private def cleanCTEMap(
133-
cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
134-
) = {
127+
private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = {
135128
cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
136-
val (_, currentRefCount, refMap) = cteMap(currentCTEId)
137-
if (currentRefCount == 0) {
138-
refMap.foreach { case (referencedCTEId, uselessRefCount) =>
139-
val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
140-
cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
129+
val refInfo = cteMap(currentCTEId)
130+
if (refInfo.refCount == 0) {
131+
refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) =>
132+
cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
141133
}
142134
}
143135
}
144136
}
145137

146138
private def inlineCTE(
147139
plan: LogicalPlan,
148-
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
149-
notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
140+
cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = {
150141
plan match {
151142
case WithCTE(child, cteDefs) =>
152-
cteDefs.foreach { cteDef =>
153-
val (cte, refCount, refMap) = cteMap(cteDef.id)
154-
if (refCount > 0) {
155-
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
156-
cteMap(cteDef.id) = (inlined, refCount, refMap)
157-
if (!shouldInline(inlined, refCount)) {
158-
notInlined.append(inlined)
159-
}
143+
val remainingDefs = cteDefs.filter { cteDef =>
144+
val refInfo = cteMap(cteDef.id)
145+
if (refInfo.refCount > 0) {
146+
val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap))
147+
val inlineDecision = shouldInline(newDef, refInfo.refCount)
148+
cteMap(cteDef.id) = cteMap(cteDef.id).copy(
149+
cteDef = newDef, shouldInline = inlineDecision
150+
)
151+
// Retain the not-inlined CTE relations in place.
152+
!inlineDecision
153+
} else {
154+
keepDanglingRelations
160155
}
161156
}
162-
inlineCTE(child, cteMap, notInlined)
157+
val inlined = inlineCTE(child, cteMap)
158+
if (remainingDefs.isEmpty) {
159+
inlined
160+
} else {
161+
WithCTE(inlined, remainingDefs)
162+
}
163163

164164
case ref: CTERelationRef =>
165-
val (cteDef, refCount, _) = cteMap(ref.cteId)
166-
if (shouldInline(cteDef, refCount)) {
167-
if (ref.outputSet == cteDef.outputSet) {
168-
cteDef.child
165+
val refInfo = cteMap(ref.cteId)
166+
if (refInfo.shouldInline) {
167+
if (ref.outputSet == refInfo.cteDef.outputSet) {
168+
refInfo.cteDef.child
169169
} else {
170170
val ctePlan = DeduplicateRelations(
171-
Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1)
171+
Join(
172+
refInfo.cteDef.child,
173+
refInfo.cteDef.child,
174+
Inner,
175+
None,
176+
JoinHint(None, None)
177+
)
178+
).children(1)
172179
val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) =>
173180
if (srcAttr.semanticEquals(tgtAttr)) {
174181
tgtAttr
@@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
184191

185192
case _ if plan.containsPattern(CTE) =>
186193
plan
187-
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
194+
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap)))
188195
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
189196
case e: SubqueryExpression =>
190-
e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
197+
e.withNewPlan(inlineCTE(e.plan, cteMap))
191198
}
192199

193200
case _ => plan
194201
}
195202
}
196203
}
204+
205+
/**
206+
* The bookkeeping information for tracking CTE relation references.
207+
*
208+
* @param cteDef The CTE relation definition
209+
* @param refCount The number of incoming references to this CTE relation. This includes references
210+
* from other CTE relations and regular places.
211+
* @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations.
212+
* @param shouldInline If true, this CTE relation should be inlined in the places that reference it.
213+
*/
214+
case class CTEReferenceInfo(
215+
cteDef: CTERelationDef,
216+
refCount: Int,
217+
outgoingRefs: mutable.Map[Long, Int],
218+
shouldInline: Boolean) {
219+
220+
def withRefCountIncreased(count: Int): CTEReferenceInfo = {
221+
copy(refCount = refCount + count)
222+
}
223+
224+
def withRefCountDecreased(count: Int): CTEReferenceInfo = {
225+
copy(refCount = refCount - count)
226+
}
227+
228+
def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = {
229+
outgoingRefs(cteDefId) += count
230+
}
231+
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.analysis.TestRelation
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
23+
import org.apache.spark.sql.catalyst.plans.PlanTest
24+
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE}
25+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
26+
27+
class InlineCTESuite extends PlanTest {
28+
29+
object Optimize extends RuleExecutor[LogicalPlan] {
30+
val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil
31+
}
32+
33+
test("SPARK-48307: not-inlined CTE relation in command") {
34+
val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a")))
35+
val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming)
36+
val plan = AppendData.byName(
37+
TestRelation(Seq($"a".double)),
38+
WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef))
39+
).analyze
40+
comparePlans(Optimize.execute(plan), plan)
41+
}
42+
}

0 commit comments

Comments
 (0)