diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 4e3234f9c0dc..7b9fd514e745 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Command, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, CTERelationDef, CTERelationRef, InsertIntoDir, InsertIntoStatement, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.TypeUtils._
@@ -52,7 +52,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
     if (!plan.containsPattern(UNRESOLVED_WITH)) {
       return plan
     }
-    val isCommand = plan.exists {
+    // New plan with CTEs moved to command's query.
+    val (planWithCTE, isCommandWithCTE) = plan match {
+      case UnresolvedWith(child: InsertIntoStatement, cteRelations) =>
+        (child.copy(query = UnresolvedWith(child.query, cteRelations)), true)
+      case _ => (plan, false)
+    }
+    val isCommand = !isCommandWithCTE && plan.exists {
       case _: Command | _: ParsedStatement | _: InsertIntoDir => true
       case _ => false
     }
@@ -60,12 +66,12 @@ object CTESubstitution extends Rule[LogicalPlan] {
     val (substituted, firstSubstituted) =
       LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match {
         case LegacyBehaviorPolicy.EXCEPTION =>
-          assertNoNameConflictsInCTE(plan)
-          traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs)
+          assertNoNameConflictsInCTE(planWithCTE)
+          traverseAndSubstituteCTE(planWithCTE, isCommand, Seq.empty, cteDefs)
         case LegacyBehaviorPolicy.LEGACY =>
-          (legacyTraverseAndSubstituteCTE(plan, cteDefs), None)
+          (legacyTraverseAndSubstituteCTE(planWithCTE, cteDefs), None)
         case LegacyBehaviorPolicy.CORRECTED =>
-          traverseAndSubstituteCTE(plan, isCommand, Seq.empty, cteDefs)
+          traverseAndSubstituteCTE(planWithCTE, isCommand, Seq.empty, cteDefs)
       }
     if (cteDefs.isEmpty) {
       substituted
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out
index e53480e96bed..f58f8faa0be3 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out
@@ -452,10 +452,14 @@ with test as (select 42) insert into test select * from test
 -- !query analysis
 InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/test, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/test], Append, `spark_catalog`.`default`.`test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/test), [i]
 +- Project [cast(42#x as int) AS i#x]
-   +- Project [42#x]
-      +- SubqueryAlias test
-         +- Project [42 AS 42#x]
-            +- OneRowRelation
+   +- WithCTE
+      :- CTERelationDef xxxx, false
+      :  +- SubqueryAlias test
+      :     +- Project [42 AS 42#x]
+      :        +- OneRowRelation
+      +- Project [42#x]
+         +- SubqueryAlias test
+            +- CTERelationRef xxxx, true, [42#x]
 
 
 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 48bdd799017c..54eb4f4b2df1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -2528,6 +2528,14 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-44356: CTE on top of INSERT INTO") {
+    withTable("t") {
+      sql("CREATE TABLE t(i int, part1 int, part2 int) using parquet")
+      sql("WITH v1(c1) as (values (1)) INSERT INTO t select c1, 2, 3 from v1")
+      checkAnswer(spark.table("t"), Row(1, 2, 3))
+    }
+  }
+
   test("SELECT clause with star wildcard") {
     withTable("t1") {
       sql("CREATE TABLE t1(c1 int, c2 string) using parquet")