diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 4e3234f9c0dc..09bf49d39360 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Command, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, CTERelationRef, LogicalPlan, SubqueryAlias, UnresolvedWith, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.TypeUtils._
@@ -30,8 +30,7 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega
 /**
  * Analyze WITH nodes and substitute child plan with CTE references or CTE definitions depending
  * on the conditions below:
- * 1. If in legacy mode, or if the query is a SQL command or DML statement, replace with CTE
- *    definitions, i.e., inline CTEs.
+ * 1. If in legacy mode, replace with CTE definitions, i.e., inline CTEs.
  * 2. Otherwise, replace with CTE references `CTERelationRef`s. The decision to inline or not
  *    inline will be made later by the rule `InlineCTE` after query analysis.
  *
@@ -46,42 +45,41 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega
  * dependency for any valid CTE query (i.e., given CTE definitions A and B with B referencing A,
  * A is guaranteed to appear before B). Otherwise, it must be an invalid user query, and an
  * analysis exception will be thrown later by relation resolving rules.
+ *
+ * If the query is a SQL command or DML statement (i.e., it extends `CTEInChildren`), place the
+ * `WithCTE` node into its children.
  */
 object CTESubstitution extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     if (!plan.containsPattern(UNRESOLVED_WITH)) {
       return plan
     }
-    val isCommand = plan.exists {
-      case _: Command | _: ParsedStatement | _: InsertIntoDir => true
-      case _ => false
-    }
     val cteDefs = ArrayBuffer.empty[CTERelationDef]
     val (substituted, firstSubstituted) =
       LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match {
         case LegacyBehaviorPolicy.EXCEPTION =>
           assertNoNameConflictsInCTE(plan)
-          traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs)
+          traverseAndSubstituteCTE(plan, Seq.empty, cteDefs)
         case LegacyBehaviorPolicy.LEGACY =>
           (legacyTraverseAndSubstituteCTE(plan, cteDefs), None)
         case LegacyBehaviorPolicy.CORRECTED =>
-          traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs)
+          traverseAndSubstituteCTE(plan, Seq.empty, cteDefs)
       }
     if (cteDefs.isEmpty) {
       substituted
     } else if (substituted eq firstSubstituted.get) {
-      WithCTE(substituted, cteDefs.toSeq)
+      withCTEDefs(substituted, cteDefs.toSeq)
     } else {
       var done = false
       substituted.resolveOperatorsWithPruning(_ => !done) {
         case p if p eq firstSubstituted.get =>
           // `firstSubstituted` is the parent of all other CTEs (if any).
           done = true
-          WithCTE(p, cteDefs.toSeq)
+          withCTEDefs(p, cteDefs.toSeq)
         case p if p.children.count(_.containsPattern(CTE)) > 1 =>
           // This is the first common parent of all CTEs.
           done = true
-          WithCTE(p, cteDefs.toSeq)
+          withCTEDefs(p, cteDefs.toSeq)
       }
     }
   }
@@ -131,7 +129,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
     plan.resolveOperatorsUp {
       case UnresolvedWith(child, relations) =>
        val resolvedCTERelations =
-          resolveCTERelations(relations, isLegacy = true, isCommand = false, Seq.empty, cteDefs)
+          resolveCTERelations(relations, isLegacy = true, Seq.empty, cteDefs)
        substituteCTE(child, alwaysInline = true, resolvedCTERelations)
     }
   }
@@ -168,7 +166,6 @@ object CTESubstitution extends Rule[LogicalPlan] {
   *       SELECT * FROM t
   *     )
   * @param plan the plan to be traversed
-  * @param isCommand if this is a command
   * @param outerCTEDefs already resolved outer CTE definitions with names
   * @param cteDefs all accumulated CTE definitions
   * @return the plan where CTE substitution is applied and optionally the last substituted `With`
@@ -176,7 +173,6 @@ object CTESubstitution extends Rule[LogicalPlan] {
   */
  private def traverseAndSubstituteCTE(
      plan: LogicalPlan,
-      isCommand: Boolean,
      outerCTEDefs: Seq[(String, CTERelationDef)],
      cteDefs: ArrayBuffer[CTERelationDef]): (LogicalPlan, Option[LogicalPlan]) = {
    var firstSubstituted: Option[LogicalPlan] = None
@@ -184,11 +180,11 @@ object CTESubstitution extends Rule[LogicalPlan] {
        _.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
      case UnresolvedWith(child: LogicalPlan, relations) =>
        val resolvedCTERelations =
-          resolveCTERelations(relations, isLegacy = false, isCommand, outerCTEDefs, cteDefs) ++
+          resolveCTERelations(relations, isLegacy = false, outerCTEDefs, cteDefs) ++
            outerCTEDefs
        val substituted = substituteCTE(
-          traverseAndSubstituteCTE(child, isCommand, resolvedCTERelations, cteDefs)._1,
-          isCommand,
+          traverseAndSubstituteCTE(child, resolvedCTERelations, cteDefs)._1,
+          alwaysInline = false,
          resolvedCTERelations)
        if (firstSubstituted.isEmpty) {
          firstSubstituted = Some(substituted)
@@ -206,10 +202,9 @@ object CTESubstitution extends Rule[LogicalPlan] {
  private def resolveCTERelations(
      relations: Seq[(String, SubqueryAlias)],
      isLegacy: Boolean,
-      isCommand: Boolean,
      outerCTEDefs: Seq[(String, CTERelationDef)],
      cteDefs: ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = {
-    var resolvedCTERelations = if (isLegacy || isCommand) {
+    var resolvedCTERelations = if (isLegacy) {
      Seq.empty
    } else {
      outerCTEDefs
@@ -232,12 +227,12 @@ object CTESubstitution extends Rule[LogicalPlan] {
        //   WITH t3 AS (SELECT * FROM t1)
        // )
        // t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`.
-        traverseAndSubstituteCTE(relation, isCommand, resolvedCTERelations, cteDefs)._1
+        traverseAndSubstituteCTE(relation, resolvedCTERelations, cteDefs)._1
      }
      // CTE definition can reference a previous one
-      val substituted = substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations)
+      val substituted = substituteCTE(innerCTEResolved, isLegacy, resolvedCTERelations)
      val cteRelation = CTERelationDef(substituted)
-      if (!(isLegacy || isCommand)) {
+      if (!isLegacy) {
        cteDefs += cteRelation
      }
      // Prepending new CTEs makes sure that those have higher priority over outer ones.
@@ -249,7 +244,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
  private def substituteCTE(
      plan: LogicalPlan,
      alwaysInline: Boolean,
-      cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan =
+      cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = {
    plan.resolveOperatorsUpWithPruning(
        _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
      case RelationTimeTravel(UnresolvedRelation(Seq(table), _, _), _, _)
@@ -273,4 +268,21 @@ object CTESubstitution extends Rule[LogicalPlan] {
          e.withNewPlan(apply(substituteCTE(e.plan, alwaysInline, cteRelations)))
      }
    }
+  }
+
+  /**
+   * Find all logical nodes that should hold the `WithCTE` node in their children, such as
+   * `InsertIntoStatement`, and place `WithCTE` on top of their children instead of on top of
+   * the whole plan. If there is no such node, put `WithCTE` on top of the plan.
+   */
+  private def withCTEDefs(p: LogicalPlan, cteDefs: Seq[CTERelationDef]): LogicalPlan = {
+    val withCTE = WithCTE(p, cteDefs)
+    var onTop = true
+    val newPlan = p.resolveOperatorsDown {
+      case cteInChildren: CTEInChildren =>
+        onTop = false
+        cteInChildren.withCTE(withCTE)
+    }
+    if (onTop) withCTE else newPlan
+  }
 }
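The heart of the change is the new `withCTEDefs` helper above. As a quick sanity model of its placement rule, here is a self-contained Scala sketch with toy plan nodes (hypothetical stand-ins, not the real Catalyst classes): a command-like node hosts the CTE wrapper in its children, while a plain query keeps it on top, as before.

```scala
// Toy model of the placement logic in withCTEDefs. All node types here
// are hypothetical stand-ins for Catalyst operators.
object WithCteDefsSketch extends App {
  sealed trait Plan { def children: Seq[Plan] }
  case class Scan(name: String) extends Plan { val children = Nil }
  case class WithCte(plan: Plan, cteDefs: Seq[String]) extends Plan {
    val children = Seq(plan)
  }
  // Command-like node, analogous to a CTEInChildren implementation.
  case class InsertInto(table: String, query: Plan) extends Plan {
    val children = Seq(query)
  }

  def withCteDefs(p: Plan, cteDefs: Seq[String]): Plan = p match {
    // The command hosts WithCte in its children, not above itself.
    case InsertInto(t, q) => InsertInto(t, WithCte(q, cteDefs))
    // A plain query keeps WithCte on top, as before this change.
    case other => WithCte(other, cteDefs)
  }

  // WITH t AS (SELECT ...) INSERT INTO tbl SELECT * FROM t
  println(withCteDefs(InsertInto("tbl", Scan("t")), Seq("t")))
  // InsertInto(tbl,WithCte(Scan(t),List(t)))
}
```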
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index f8ba042009b2..4cf09a9a734a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -677,7 +677,7 @@ case class InsertIntoDir(
     provider: Option[String],
     child: LogicalPlan,
     overwrite: Boolean = true)
-  extends UnaryNode {
+  extends UnaryNode with CTEInChildren {
 
   override def output: Seq[Attribute] = Seq.empty
   override def metadataOutput: Seq[Attribute] = Nil
@@ -896,6 +896,15 @@ case class WithWindowDefinition(
     copy(child = newChild)
 }
 
+/**
+ * A trait for logical nodes that can embed the given `WithCTE` node into their children.
+ */
+trait CTEInChildren extends LogicalPlan {
+  def withCTE(withCTE: WithCTE): LogicalPlan = {
+    withNewChildren(children.map(withCTE.withNewPlan))
+  }
+}
+
 /**
  * @param order The ordering expressions
  * @param global True means global sorting apply for entire data set,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index 669750ee448d..9efc3b13bc26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.DataType
  * Parsed logical plans are located in Catalyst so that as much SQL parsing logic as possible is
  * kept in a [[org.apache.spark.sql.catalyst.parser.AbstractSqlParser]].
  */
-abstract class ParsedStatement extends LogicalPlan {
+abstract class ParsedStatement extends LogicalPlan with CTEInChildren {
   // Redact properties and options when parsed nodes are used by generic methods like toString
   override def productIterator: Iterator[Any] = super.productIterator.map {
     case mapArg: Map[_, _] => conf.redactOptions(mapArg)
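The trait's default `withCTE` re-wraps every child with the same `WithCTE` template via `withCTE.withNewPlan`. A minimal standalone sketch of that per-child re-wrapping, again with toy stand-ins for the Catalyst types:

```scala
// Sketch of the trait's default behavior: each child of a CTEInChildren
// node is replaced by a copy of the WithCTE template pointing at that
// child. Types are toy stand-ins, not the Catalyst API.
object DefaultWithCteSketch extends App {
  case class WithCte(plan: String, cteDefs: Seq[String]) {
    def withNewPlan(newPlan: String): WithCte = copy(plan = newPlan)
  }

  val template = WithCte(plan = "<placeholder>", cteDefs = Seq("t1", "t2"))
  val children = Seq("left-query", "right-query")

  // Mirrors `withNewChildren(children.map(withCTE.withNewPlan))`.
  val newChildren = children.map(template.withNewPlan)
  println(newChildren)
  // List(WithCte(left-query,List(t1, t2)), WithCte(right-query,List(t1, t2)))
}
```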
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 739ffa487e39..0f31f0068195 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -46,7 +46,7 @@ trait KeepAnalyzedQuery extends Command {
 /**
  * Base trait for DataSourceV2 write commands
  */
-trait V2WriteCommand extends UnaryCommand with KeepAnalyzedQuery {
+trait V2WriteCommand extends UnaryCommand with KeepAnalyzedQuery with CTEInChildren {
   def table: NamedRelation
   def query: LogicalPlan
   def isByName: Boolean
@@ -392,9 +392,18 @@ case class WriteDelta(
   }
 }
 
-trait V2CreateTableAsSelectPlan extends V2CreateTablePlan with AnalysisOnlyCommand {
+trait V2CreateTableAsSelectPlan
+  extends V2CreateTablePlan
+    with AnalysisOnlyCommand
+    with CTEInChildren {
   def query: LogicalPlan
 
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    withNameAndQuery(
+      newName = this.name,
+      newQuery = withCTE.copy(plan = this.query))
+  }
+
   override lazy val resolved: Boolean = childrenResolved && {
     // the table schema is created from the query schema, so the only resolution needed is to check
     // that the columns referenced by the table's partitioning exist in the query schema
@@ -1234,12 +1243,18 @@ case class RepairTable(
 case class AlterViewAs(
     child: LogicalPlan,
     originalText: String,
-    query: LogicalPlan) extends BinaryCommand {
+    query: LogicalPlan) extends BinaryCommand with CTEInChildren {
   override def left: LogicalPlan = child
   override def right: LogicalPlan = query
   override protected def withNewChildrenInternal(
       newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
     copy(child = newLeft, query = newRight)
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    withNewChildrenInternal(
+      newLeft = this.left,
+      newRight = withCTE.copy(plan = this.right))
+  }
 }
 
 /**
@@ -1253,12 +1268,18 @@ case class CreateView(
     originalText: Option[String],
     query: LogicalPlan,
     allowExisting: Boolean,
-    replace: Boolean) extends BinaryCommand {
+    replace: Boolean) extends BinaryCommand with CTEInChildren {
   override def left: LogicalPlan = child
   override def right: LogicalPlan = query
   override protected def withNewChildrenInternal(
       newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
     copy(child = newLeft, query = newRight)
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    withNewChildrenInternal(
+      newLeft = this.left,
+      newRight = withCTE.copy(plan = this.right))
+  }
 }
 
 /**
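For the binary commands above (`AlterViewAs`, `CreateView`), only the query side may receive the `WithCTE` node; the side identifying the target view must stay untouched, which is why both overrides pass `this.left` through unchanged. A toy sketch of that asymmetry (hypothetical node names):

```scala
// Only the right (query) child gets wrapped; the left child, which
// identifies the view being created or altered, is left as-is.
object BinaryCommandSketch extends App {
  sealed trait Plan
  case class Ident(name: String) extends Plan
  case class Query(sql: String) extends Plan
  case class WithCte(plan: Plan, cteDefs: Seq[String]) extends Plan
  case class CreateView(left: Plan, right: Plan) extends Plan

  def withCte(cmd: CreateView, cteDefs: Seq[String]): CreateView =
    cmd.copy(right = WithCte(cmd.right, cteDefs))

  val cmd = CreateView(Ident("v"), Query("SELECT * FROM t"))
  println(withCte(cmd, Seq("t")))
  // CreateView(Ident(v),WithCte(Query(SELECT * FROM t),List(t)))
}
```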
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index 338ce8cac420..592ae04a055d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, UnaryCommand}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
@@ -35,7 +35,7 @@ import org.apache.spark.util.SerializableConfiguration
 /**
  * A special `Command` which writes data out and updates metrics.
  */
-trait DataWritingCommand extends UnaryCommand {
+trait DataWritingCommand extends UnaryCommand with CTEInChildren {
   /**
    * The input query plan that produces the data to be written.
    * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
index 35c8bec37162..0a9064261c7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources._
 
@@ -42,7 +42,7 @@ case class InsertIntoDataSourceDirCommand(
     storage: CatalogStorageFormat,
     provider: String,
     query: LogicalPlan,
-    overwrite: Boolean) extends LeafRunnableCommand {
+    overwrite: Boolean) extends LeafRunnableCommand with CTEInChildren {
 
   override def innerChildren: Seq[LogicalPlan] = query :: Nil
 
@@ -76,4 +76,8 @@ case class InsertIntoDataSourceDirCommand(
 
     Seq.empty[Row]
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 3848d5505155..b1b2fd53c74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -21,7 +21,7 @@ import java.net.URI
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.CommandExecutionMode
@@ -141,7 +141,7 @@ case class CreateDataSourceTableAsSelectCommand(
     mode: SaveMode,
     query: LogicalPlan,
     outputColumnNames: Seq[String])
-  extends LeafRunnableCommand {
+  extends LeafRunnableCommand with CTEInChildren {
   assert(query.resolved)
 
   override def innerChildren: Seq[LogicalPlan] = query :: Nil
@@ -233,4 +233,8 @@ case class CreateDataSourceTableAsSelectCommand(
         throw ex
     }
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
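`InsertIntoDataSourceDirCommand` and `CreateDataSourceTableAsSelectCommand` are leaf commands whose query lives in `innerChildren`, so the trait default, which maps over `children`, would be a no-op for them; hence the explicit `copy(query = ...)` overrides. The same pattern repeats in the remaining files below. A toy sketch of why the default would silently drop the definitions:

```scala
// Why leaf commands override withCTE explicitly. Toy types, not the
// real Catalyst API.
object LeafCommandSketch extends App {
  sealed trait Plan
  case class Query(sql: String) extends Plan
  case class WithCte(plan: Plan, cteDefs: Seq[String]) extends Plan

  // Toy leaf command: no children, the query is an inner child only.
  case class LeafInsertCommand(query: Plan) {
    val children: Seq[Plan] = Nil
    def innerChildren: Seq[Plan] = Seq(query)

    // Trait-default style: maps over children, a no-op for a leaf.
    def withCteViaChildren(defs: Seq[String]): Seq[Plan] =
      children.map(WithCte(_, defs)) // always empty here

    // Explicit override, mirroring `copy(query = withCTE.copy(plan = query))`.
    def withCte(defs: Seq[String]): LeafInsertCommand =
      copy(query = WithCte(query, defs))
  }

  val cmd = LeafInsertCommand(Query("SELECT * FROM t"))
  println(cmd.withCteViaChildren(Seq("t"))) // List() -- defs would be lost
  println(cmd.withCte(Seq("t")))            // query wrapped in WithCte
}
```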
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 351f6d5456d8..30fcf6ccdaf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -735,7 +735,7 @@ case class DescribeTableCommand(
  * 7. Common table expressions (CTEs)
  */
 case class DescribeQueryCommand(queryText: String, plan: LogicalPlan)
-  extends DescribeCommandBase {
+  extends DescribeCommandBase with CTEInChildren {
 
   override val output = DescribeCommandSchema.describeTableAttributes()
 
@@ -747,6 +747,10 @@ case class DescribeQueryCommand(queryText: String, plan: LogicalPlan)
     describeSchema(queryExecution.analyzed.schema, result, header = false)
     result.toSeq
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(plan = withCTE.copy(plan = this.plan))
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 3718794ea590..8a12b162f994 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.{SQLConfHelper, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, GlobalTempView, LocalTempView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, TemporaryViewRelation}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, LogicalPlan, Project, View}
+import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, CTEInChildren, LogicalPlan, Project, View, WithCTE}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -69,7 +69,7 @@ case class CreateViewCommand(
     viewType: ViewType,
     isAnalyzed: Boolean = false,
     referredTempFunctions: Seq[String] = Seq.empty)
-  extends RunnableCommand with AnalysisOnlyCommand {
+  extends RunnableCommand with AnalysisOnlyCommand with CTEInChildren {
 
   import ViewHelper._
 
@@ -215,6 +215,10 @@ case class CreateViewCommand(
       comment = comment
     )
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(plan = withCTE.copy(plan = this.plan))
+  }
 }
 
 /**
@@ -235,7 +239,7 @@ case class AlterViewAsCommand(
     query: LogicalPlan,
     isAnalyzed: Boolean = false,
     referredTempFunctions: Seq[String] = Seq.empty)
-  extends RunnableCommand with AnalysisOnlyCommand {
+  extends RunnableCommand with AnalysisOnlyCommand with CTEInChildren {
 
   import ViewHelper._
 
@@ -307,6 +311,10 @@ case class AlterViewAsCommand(
 
     session.sessionState.catalog.alterTable(updatedViewMeta)
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index 789b1d714fcb..7cffd6efdb70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.InsertableRelation
 
@@ -31,7 +31,7 @@ case class InsertIntoDataSourceCommand(
     logicalRelation: LogicalRelation,
     query: LogicalPlan,
     overwrite: Boolean)
-  extends LeafRunnableCommand {
+  extends LeafRunnableCommand with CTEInChildren {
 
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
@@ -47,4 +47,8 @@ case class InsertIntoDataSourceCommand(
 
     Seq.empty[Row]
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index fe6ec094812e..1c98854b81cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.SparkPlan
@@ -57,7 +57,7 @@ case class InsertIntoHadoopFsRelationCommand(
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex],
     outputColumnNames: Seq[String])
-  extends V1WriteCommand {
+  extends V1WriteCommand with CTEInChildren {
 
   private lazy val parameters = CaseInsensitiveMap(options)
 
@@ -277,4 +277,8 @@ case class InsertIntoHadoopFsRelationCommand(
 
   override protected def withNewChildInternal(
     newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    withNewChildInternal(withCTE.copy(plan = this.query))
+  }
 }
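`InsertIntoHadoopFsRelationCommand` above is a unary write command whose only child is `query`, so its override is behaviorally equivalent to the trait default; it merely routes through `withNewChildInternal`. A toy sketch of that equivalence (hypothetical node names):

```scala
// For unary write commands the child IS the query, so wrapping the child
// (trait default) and routing through withNewChildInternal coincide.
object UnaryWriteSketch extends App {
  sealed trait Plan
  case class Query(sql: String) extends Plan
  case class WithCte(plan: Plan, cteDefs: Seq[String]) extends Plan
  case class WriteCommand(query: Plan) extends Plan {
    def children: Seq[Plan] = Seq(query)
    def withNewChildInternal(newChild: Plan): WriteCommand = copy(query = newChild)
  }

  val cmd = WriteCommand(Query("SELECT * FROM t"))
  val viaDefault  = cmd.copy(query = WithCte(cmd.query, Seq("t")))
  val viaOverride = cmd.withNewChildInternal(WithCte(cmd.query, Seq("t")))
  println(viaDefault == viaOverride) // true
}
```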
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 666ae9b5c6f3..2d76e7c3afa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.CreatableRelationProvider
@@ -39,7 +39,7 @@ case class SaveIntoDataSourceCommand(
     query: LogicalPlan,
     dataSource: CreatableRelationProvider,
     options: Map[String, String],
-    mode: SaveMode) extends LeafRunnableCommand {
+    mode: SaveMode) extends LeafRunnableCommand with CTEInChildren {
 
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
@@ -68,4 +68,8 @@ case class SaveIntoDataSourceCommand(
   override def clone(): LogicalPlan = {
     SaveIntoDataSourceCommand(query.clone(), dataSource, options, mode)
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 4127e7c75d79..5bf04460f522 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, WithCTE}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.command.{DataWritingCommand, LeafRunnableCommand}
@@ -38,7 +38,7 @@ case class CreateHiveTableAsSelectCommand(
     query: LogicalPlan,
     outputColumnNames: Seq[String],
     mode: SaveMode)
-  extends LeafRunnableCommand {
+  extends LeafRunnableCommand with CTEInChildren {
   assert(query.resolved)
 
   override def innerChildren: Seq[LogicalPlan] = query :: Nil
@@ -111,4 +111,8 @@ case class CreateHiveTableAsSelectCommand(
     s"[Database: ${tableDesc.database}, " +
     s"TableName: ${tableDesc.identifier.table}]"
   }
+
+  override def withCTE(withCTE: WithCTE): LogicalPlan = {
+    copy(query = withCTE.copy(plan = this.query))
+  }
 }
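As an end-to-end smoke test of the new placement, the following hedged sketch uses only the public API (a local session and a scratch table are assumed); after this change, the analyzed plan of a DML statement with a WITH clause should show `WithCTE` nested under the command node instead of the CTE being inlined up front:

```scala
// Hedged end-to-end check of the new WithCTE placement, using only the
// public API. Assumes a local session; the exact tree shape printed
// depends on the Spark version carrying this change.
import org.apache.spark.sql.SparkSession

object CtePlacementSmokeTest extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  spark.sql("CREATE TABLE tgt(c INT) USING parquet")

  val df = spark.sql("WITH t AS (SELECT 1 AS c) INSERT INTO tgt SELECT * FROM t")
  // Before this change, commands forced CTE inlining during substitution;
  // now the analyzed plan should contain a WithCTE node under the command.
  println(df.queryExecution.analyzed.treeString)

  spark.stop()
}
```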