diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d261b5637bd3..7ed7152692cf 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -81,8 +81,7 @@ singleTableSchema statement : query #statementDefault - | insertStatement #insertStatementDefault - | multiSelectStatement #multiSelectStatementDefault + | ctes? dmlStatementNoWith #dmlStatement | USE db=identifier #use | CREATE database (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? @@ -356,14 +355,14 @@ resource : identifier STRING ; -insertStatement - : (ctes)? insertInto queryTerm queryOrganization #singleInsertQuery - | (ctes)? fromClause multiInsertQueryBody+ #multiInsertQuery +dmlStatementNoWith + : insertInto queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery ; queryNoWith : queryTerm queryOrganization #noWithQuery - | fromClause selectStatement #queryWithFrom + | fromClause selectStatement+ #queryWithFrom ; queryOrganization @@ -379,10 +378,6 @@ multiInsertQueryBody : insertInto selectStatement ; -multiSelectStatement - : (ctes)? fromClause selectStatement+ #multiSelect - ; - selectStatement : querySpecification queryOrganization ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6bb991c24175..68cd1bd2f0a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -117,6 +117,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging query.optionalMap(ctx.ctes)(withCTE) } + override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) { + val dmlStmt = plan(ctx.dmlStatementNoWith) + // Apply CTEs + dmlStmt.optionalMap(ctx.ctes)(withCTE) + } + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { val ctes = ctx.namedQuery.asScala.map { nCtx => val namedQuery = visitNamedQuery(nCtx) @@ -129,11 +135,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitQueryWithFrom(ctx: QueryWithFromContext): LogicalPlan = withOrigin(ctx) { val from = visitFromClause(ctx.fromClause) - validate(ctx.selectStatement.querySpecification.fromClause == null, - "Individual select statement can not have FROM cause as its already specified in the" + - " outer query block", ctx) - withQuerySpecification(ctx.selectStatement.querySpecification, from). - optionalMap(ctx.selectStatement.queryOrganization)(withQueryResultClauses) + val selects = ctx.selectStatement.asScala.map { select => + validate(select.querySpecification.fromClause == null, + "This select statement can not have FROM cause as its already specified upfront", + select) + + withQuerySpecification(select.querySpecification, from). + // Add organization statements. + optionalMap(select.queryOrganization)(withQueryResultClauses) + } + // If there are multiple SELECT just UNION them together into one query. + if (selects.length == 1) { + selects.head + } else { + Union(selects) + } } override def visitNoWithQuery(ctx: NoWithQueryContext): LogicalPlan = withOrigin(ctx) { @@ -182,36 +198,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } // If there are multiple INSERTS just UNION them together into one query. - val insertPlan = inserts match { - case Seq(query) => query - case queries => Union(queries) - } - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) - } - - override def visitMultiSelect(ctx: MultiSelectContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - - // Build the insert clauses. - val selects = ctx.selectStatement.asScala.map { - body => - validate(body.querySpecification.fromClause == null, - "Multi-select queries cannot have a FROM clause in their individual SELECT statements", - body) - - withQuerySpecification(body.querySpecification, from). - // Add organization statements. - optionalMap(body.queryOrganization)(withQueryResultClauses) - } - - // If there are multiple INSERTS just UNION them together into one query. - val selectUnionPlan = selects match { - case Seq(query) => query - case queries => Union(queries) + if (inserts.length == 1) { + inserts.head + } else { + Union(inserts) } - // Apply CTEs - selectUnionPlan.optionalMap(ctx.ctes)(withCTE) } /** @@ -219,10 +210,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitSingleInsertQuery( ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - val insertPlan = withInsertInto(ctx.insertInto(), - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) + withInsertInto( + ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 9208d9358016..00836d352177 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -132,15 +132,19 @@ class PlanParserSuite extends AnalysisTest { table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", - "Multi-select queries cannot have a FROM clause in their individual SELECT statements") + "This select statement can not have FROM cause as its already specified upfront") intercept( "from a select * from b", - "Individual select statement can not have FROM cause as its already specified in " + - "the outer query block") + "This select statement can not have FROM cause as its already specified upfront") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( table("a").where('s < 10).select(star()).insertInto("tbl2"))) + assertEqual( + "select * from (from a select * select *)", + table("a").select(star()) + .union(table("a").select(star())) + .as("__auto_generated_subquery_name").select(star())) } test("query organization") {