diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index c21f330be064..378627f320c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
-import org.apache.spark.sql.catalyst.trees.LeafLike
+import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -51,6 +51,8 @@ trait RunnableCommand extends Command {
 
 trait LeafRunnableCommand extends RunnableCommand with LeafLike[LogicalPlan]
 
+trait UnaryRunnableCommand extends RunnableCommand with UnaryLike[LogicalPlan]
+
 /**
  * A physical operator that executes the run method of a `RunnableCommand` and
  * saves the result to prevent multiple executions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 3848d5505155..dd748d67ab71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -141,9 +141,13 @@ case class CreateDataSourceTableAsSelectCommand(
     mode: SaveMode,
     query: LogicalPlan,
     outputColumnNames: Seq[String])
-  extends LeafRunnableCommand {
+  extends UnaryRunnableCommand {
   assert(query.resolved)
-  override def innerChildren: Seq[LogicalPlan] = query :: Nil
+  override def child: LogicalPlan = query
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+    copy(query = newChild)
+  }
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     assert(table.tableType != CatalogTableType.VIEW)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03a3acaf526f..e1b60b687cab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4667,6 +4667,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         |SELECT * FROM z
         |""".stripMargin).collect()
   }
+
+  test("SPARK-43883: CTAS commands are unary nodes") {
+    withTable("t") {
+      val ctasQuery = sql("CREATE TABLE t USING parquet AS SELECT 1")
+      assert(ctasQuery.logicalPlan.containsChild.size == 1)
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 4127e7c75d79..319a9a049069 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.hive.execution
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{Row, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.command.{DataWritingCommand, LeafRunnableCommand}
+import org.apache.spark.sql.execution.command.{DataWritingCommand, UnaryRunnableCommand}
 
 /**
  * Create table and insert the query result into it.
@@ -38,11 +39,14 @@ case class CreateHiveTableAsSelectCommand(
     query: LogicalPlan,
     outputColumnNames: Seq[String],
     mode: SaveMode)
-  extends LeafRunnableCommand {
+  extends UnaryRunnableCommand {
   assert(query.resolved)
-  override def innerChildren: Seq[LogicalPlan] = query :: Nil
+  protected val tableIdentifier: TableIdentifier = tableDesc.identifier
+  override def child: LogicalPlan = query
 
-  protected val tableIdentifier = tableDesc.identifier
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+    copy(query = newChild)
+  }
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
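
A minimal, self-contained sketch of the pattern this diff applies. The MiniNode, MiniRelation, and MiniCtas names below are hypothetical illustrations, not Spark classes: once a CTAS command exposes its query as a real child (as UnaryRunnableCommand does via UnaryLike) instead of hiding it in innerChildren, generic bottom-up tree rewrites reach the query, and a withNewChildInternal-style copy tells the framework how to rebuild the command around the rewritten child.

// Hypothetical toy tree, not Spark's TreeNode; it only mirrors the shape of
// UnaryLike#withNewChildInternal so the diff's intent can run in isolation.
sealed trait MiniNode {
  def children: Seq[MiniNode]
  def withNewChildren(newChildren: Seq[MiniNode]): MiniNode

  // Bottom-up rewrite, loosely analogous to TreeNode.transformUp.
  def transformUp(rule: PartialFunction[MiniNode, MiniNode]): MiniNode = {
    val rewritten = withNewChildren(children.map(_.transformUp(rule)))
    rule.applyOrElse(rewritten, identity[MiniNode])
  }
}

// Leaf node: there is nothing for a transformation to descend into.
case class MiniRelation(name: String) extends MiniNode {
  def children: Seq[MiniNode] = Nil
  def withNewChildren(newChildren: Seq[MiniNode]): MiniNode = this
}

// Unary command: the query is a real child, as with UnaryRunnableCommand.
case class MiniCtas(table: String, query: MiniNode) extends MiniNode {
  def children: Seq[MiniNode] = query :: Nil
  // Mirrors withNewChildInternal: rebuild the command around the new child.
  def withNewChildren(newChildren: Seq[MiniNode]): MiniNode =
    copy(query = newChildren.head)
}

object UnaryCtasDemo extends App {
  val plan = MiniCtas("t", MiniRelation("src"))
  // The rule reaches the query only because it is exposed as a child; a
  // leaf command keeping the query in innerChildren would return unchanged.
  val rewritten = plan.transformUp {
    case MiniRelation("src") => MiniRelation("src_rewritten")
  }
  assert(rewritten == MiniCtas("t", MiniRelation("src_rewritten")))
  println(rewritten)
}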