@@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
3737 * query level.
3838 *
3939 * @param alwaysInline if true, inline all CTEs in the query plan.
40+ * @param keepDanglingRelations if true, dangling CTE relations will be kept in the original
41+ * `WithCTE` node.
4042 */
41- case class InlineCTE (alwaysInline : Boolean = false ) extends Rule [LogicalPlan ] {
43+ case class InlineCTE (
44+ alwaysInline : Boolean = false ,
45+ keepDanglingRelations : Boolean = false ) extends Rule [LogicalPlan ] {
4246
4347 override def apply (plan : LogicalPlan ): LogicalPlan = {
4448 if (! plan.isInstanceOf [Subquery ] && plan.containsPattern(CTE )) {
45- val cteMap = mutable.SortedMap .empty[Long , ( CTERelationDef , Int , mutable. Map [ Long , Int ]) ]
49+ val cteMap = mutable.SortedMap .empty[Long , CTEReferenceInfo ]
4650 buildCTEMap(plan, cteMap)
4751 cleanCTEMap(cteMap)
48- val notInlined = mutable.ArrayBuffer .empty[CTERelationDef ]
49- val inlined = inlineCTE(plan, cteMap, notInlined)
50- // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
51- // WithCTE as top node here.
52- if (notInlined.isEmpty) {
53- inlined
54- } else {
55- WithCTE (inlined, notInlined.toSeq)
56- }
52+ inlineCTE(plan, cteMap)
5753 } else {
5854 plan
5955 }
@@ -74,34 +70,33 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
7470 *
7571 * @param plan The plan to collect the CTEs from
7672 * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
77- * ids. The value of the map is tuple whose elements are:
78- * - The CTE definition
79- * - The number of incoming references to the CTE. This includes references from
80- * other CTEs and regular places.
81- * - A mutable inner map that tracks outgoing references (counts) to other CTEs.
73+ * ids.
8274 * @param outerCTEId While collecting the map we use this optional CTE id to identify the
8375 * current outer CTE.
8476 */
85- def buildCTEMap (
77+ private def buildCTEMap (
8678 plan : LogicalPlan ,
87- cteMap : mutable.Map [Long , ( CTERelationDef , Int , mutable. Map [ Long , Int ]) ],
79+ cteMap : mutable.Map [Long , CTEReferenceInfo ],
8880 outerCTEId : Option [Long ] = None ): Unit = {
8981 plan match {
9082 case WithCTE (child, cteDefs) =>
9183 cteDefs.foreach { cteDef =>
92- cteMap(cteDef.id) = (cteDef, 0 , mutable.Map .empty.withDefaultValue(0 ))
84+ cteMap(cteDef.id) = CTEReferenceInfo (
85+ cteDef = cteDef,
86+ refCount = 0 ,
87+ outgoingRefs = mutable.Map .empty.withDefaultValue(0 ),
88+ shouldInline = true
89+ )
9390 }
9491 cteDefs.foreach { cteDef =>
9592 buildCTEMap(cteDef, cteMap, Some (cteDef.id))
9693 }
9794 buildCTEMap(child, cteMap, outerCTEId)
9895
9996 case ref : CTERelationRef =>
100- val (cteDef, refCount, refMap) = cteMap(ref.cteId)
101- cteMap(ref.cteId) = (cteDef, refCount + 1 , refMap)
97+ cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1 )
10298 outerCTEId.foreach { cteId =>
103- val (_, _, outerRefMap) = cteMap(cteId)
104- outerRefMap(ref.cteId) += 1
99+ cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1 )
105100 }
106101
107102 case _ =>
@@ -129,46 +124,58 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
129124 * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
130125 * ids. Needs to be sorted to speed up cleaning.
131126 */
132- private def cleanCTEMap (
133- cteMap : mutable.SortedMap [Long , (CTERelationDef , Int , mutable.Map [Long , Int ])]
134- ) = {
127+ private def cleanCTEMap (cteMap : mutable.SortedMap [Long , CTEReferenceInfo ]): Unit = {
135128 cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
136- val (_, currentRefCount, refMap) = cteMap(currentCTEId)
137- if (currentRefCount == 0 ) {
138- refMap.foreach { case (referencedCTEId, uselessRefCount) =>
139- val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
140- cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
129+ val refInfo = cteMap(currentCTEId)
130+ if (refInfo.refCount == 0 ) {
131+ refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) =>
132+ cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
141133 }
142134 }
143135 }
144136 }
145137
146138 private def inlineCTE (
147139 plan : LogicalPlan ,
148- cteMap : mutable.Map [Long , (CTERelationDef , Int , mutable.Map [Long , Int ])],
149- notInlined : mutable.ArrayBuffer [CTERelationDef ]): LogicalPlan = {
140+ cteMap : mutable.Map [Long , CTEReferenceInfo ]): LogicalPlan = {
150141 plan match {
151142 case WithCTE (child, cteDefs) =>
152- cteDefs.foreach { cteDef =>
153- val (cte, refCount, refMap) = cteMap(cteDef.id)
154- if (refCount > 0 ) {
155- val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
156- cteMap(cteDef.id) = (inlined, refCount, refMap)
157- if (! shouldInline(inlined, refCount)) {
158- notInlined.append(inlined)
159- }
143+ val remainingDefs = cteDefs.filter { cteDef =>
144+ val refInfo = cteMap(cteDef.id)
145+ if (refInfo.refCount > 0 ) {
146+ val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap))
147+ val inlineDecision = shouldInline(newDef, refInfo.refCount)
148+ cteMap(cteDef.id) = cteMap(cteDef.id).copy(
149+ cteDef = newDef, shouldInline = inlineDecision
150+ )
151+ // Retain the not-inlined CTE relations in place.
152+ ! inlineDecision
153+ } else {
154+ keepDanglingRelations
160155 }
161156 }
162- inlineCTE(child, cteMap, notInlined)
157+ val inlined = inlineCTE(child, cteMap)
158+ if (remainingDefs.isEmpty) {
159+ inlined
160+ } else {
161+ WithCTE (inlined, remainingDefs)
162+ }
163163
164164 case ref : CTERelationRef =>
165- val (cteDef, refCount, _) = cteMap(ref.cteId)
166- if (shouldInline(cteDef, refCount) ) {
167- if (ref.outputSet == cteDef.outputSet) {
168- cteDef.child
165+ val refInfo = cteMap(ref.cteId)
166+ if (refInfo. shouldInline) {
167+ if (ref.outputSet == refInfo. cteDef.outputSet) {
168+ refInfo. cteDef.child
169169 } else {
170170 val ctePlan = DeduplicateRelations (
171- Join (cteDef.child, cteDef.child, Inner , None , JoinHint (None , None ))).children(1 )
171+ Join (
172+ refInfo.cteDef.child,
173+ refInfo.cteDef.child,
174+ Inner ,
175+ None ,
176+ JoinHint (None , None )
177+ )
178+ ).children(1 )
172179 val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) =>
173180 if (srcAttr.semanticEquals(tgtAttr)) {
174181 tgtAttr
@@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
184191
185192 case _ if plan.containsPattern(CTE ) =>
186193 plan
187- .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined )))
194+ .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap)))
188195 .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION , CTE )) {
189196 case e : SubqueryExpression =>
190- e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined ))
197+ e.withNewPlan(inlineCTE(e.plan, cteMap))
191198 }
192199
193200 case _ => plan
194201 }
195202 }
196203}
204+
205+ /**
206+ * The bookkeeping information for tracking CTE relation references.
207+ *
208+ * @param cteDef The CTE relation definition
209+ * @param refCount The number of incoming references to this CTE relation. This includes references
210+ * from other CTE relations and regular places.
211+ * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations.
212+ * @param shouldInline If true, this CTE relation should be inlined in the places that reference it.
213+ */
214+ case class CTEReferenceInfo (
215+ cteDef : CTERelationDef ,
216+ refCount : Int ,
217+ outgoingRefs : mutable.Map [Long , Int ],
218+ shouldInline : Boolean ) {
219+
220+ def withRefCountIncreased (count : Int ): CTEReferenceInfo = {
221+ copy(refCount = refCount + count)
222+ }
223+
224+ def withRefCountDecreased (count : Int ): CTEReferenceInfo = {
225+ copy(refCount = refCount - count)
226+ }
227+
228+ def increaseOutgoingRefCount (cteDefId : Long , count : Int ): Unit = {
229+ outgoingRefs(cteDefId) += count
230+ }
231+ }
0 commit comments