From 5da54d261367dc19bc55911df9e2cc2e0dd21eb8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 9 Dec 2019 22:21:36 +0800 Subject: [PATCH 1/6] support column position in DS v2 --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/connector/catalog/IdentifierImpl.java | 21 +-- .../sql/connector/catalog/TableChange.java | 169 ++++++++++++++++-- .../catalyst/analysis/ResolveCatalogs.scala | 23 ++- .../sql/catalyst/parser/AstBuilder.scala | 28 +-- .../catalyst/plans/logical/statements.scala | 10 +- .../catalog/CatalogV2Implicits.scala | 14 +- .../sql/connector/catalog/CatalogV2Util.scala | 7 + .../sql/catalyst/parser/DDLParserSuite.scala | 71 +++++--- .../analysis/ResolveSessionCatalog.scala | 27 ++- .../spark/sql/connector/AlterTableTests.scala | 30 ++++ .../sql/connector/DataSourceV2SQLSuite.scala | 4 +- .../sql/execution/command/DDLSuite.scala | 10 ++ 13 files changed, 331 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index cd9748eaa6f28..9b8c4a52fab68 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -855,7 +855,7 @@ intervalUnit ; colPosition - : FIRST | AFTER multipartIdentifier + : position=FIRST | position=AFTER multipartIdentifier ; dataType diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java index 56d13ef742cea..cfc71bbb78727 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.connector.catalog; -import com.google.common.base.Preconditions; -import org.apache.spark.annotation.Experimental; - import java.util.Arrays; import java.util.Objects; -import java.util.stream.Collectors; import java.util.stream.Stream; +import com.google.common.base.Preconditions; + +import org.apache.spark.annotation.Experimental; + /** * An {@link Identifier} implementation. 
*/ @@ -51,19 +51,10 @@ public String name() { return name; } - private String escapeQuote(String part) { - if (part.contains("`")) { - return part.replace("`", "``"); - } else { - return part; - } - } - @Override public String toString() { - return Stream.concat(Stream.of(namespace), Stream.of(name)) - .map(part -> '`' + escapeQuote(part) + '`') - .collect(Collectors.joining(".")); + return CatalogV2Implicits.quoteNameParts(Stream.concat( + Stream.of(namespace), Stream.of(name)).toArray(String[]::new)); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 20c22388b0ef9..5012deb15fef1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -17,11 +17,12 @@ package org.apache.spark.sql.connector.catalog; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.types.DataType; - import java.util.Arrays; import java.util.Objects; +import javax.annotation.Nullable; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; /** * TableChange subclasses represent requested changes to a table. These are passed to @@ -76,7 +77,7 @@ static TableChange removeProperty(String property) { * @return a TableChange for the addition */ static TableChange addColumn(String[] fieldNames, DataType dataType) { - return new AddColumn(fieldNames, dataType, true, null); + return new AddColumn(fieldNames, dataType, true, null, null); } /** @@ -92,7 +93,7 @@ static TableChange addColumn(String[] fieldNames, DataType dataType) { * @return a TableChange for the addition */ static TableChange addColumn(String[] fieldNames, DataType dataType, boolean isNullable) { - return new AddColumn(fieldNames, dataType, isNullable, null); + return new AddColumn(fieldNames, dataType, isNullable, null, null); } /** @@ -113,7 +114,30 @@ static TableChange addColumn( DataType dataType, boolean isNullable, String comment) { - return new AddColumn(fieldNames, dataType, isNullable, comment); + return new AddColumn(fieldNames, dataType, isNullable, comment, null); + } + + /** + * Create a TableChange for adding a column. + *
+   * If the field already exists, the change will result in an {@link IllegalArgumentException}.
+   * If the new field is nested and its parent does not exist or is not a struct, the change will
+   * result in an {@link IllegalArgumentException}.
+   *
+   * @param fieldNames field names of the new column
+   * @param dataType the new column's data type
+   * @param isNullable whether the new column can contain null
+   * @param comment the new field's comment string
+   * @param position the new column's position
+   * @return a TableChange for the addition
+   */
+  static TableChange addColumn(
+      String[] fieldNames,
+      DataType dataType,
+      boolean isNullable,
+      String comment,
+      ColumnPosition position) {
+    return new AddColumn(fieldNames, dataType, isNullable, comment, position);
   }
 
   /**
@@ -180,6 +204,21 @@ static TableChange updateColumnComment(String[] fieldNames, String newComment) {
     return new UpdateColumnComment(fieldNames, newComment);
   }
 
+  /**
+   * Create a TableChange for updating the position of a field.
+   * <p>
+   * The name is used to find the field to update.
+   * <p>
+   * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+   *
+   * @param fieldNames field names of the column to update
+   * @param newPosition the new position
+   * @return a TableChange for the update
+   */
+  static TableChange updateColumnPosition(String[] fieldNames, ColumnPosition newPosition) {
+    return new UpdateColumnPosition(fieldNames, newPosition);
+  }
+
   /**
    * Create a TableChange for deleting a field.
    * <p>
@@ -259,6 +298,60 @@ public int hashCode() { } } + interface ColumnPosition { + First FIRST = new First(); + + static ColumnPosition After(String[] column) { + return new After(column); + } + } + + /** + * Column position FIRST means the specified column should be the first column. + */ + final class First implements ColumnPosition { + private First() {} + + @Override + public String toString() { + return "FIRST"; + } + } + + /** + * Column position AFTER means the specified column should be put after the given `column`. + */ + final class After implements ColumnPosition { + private final String[] column; + + private After(String[] column) { + assert column != null; + this.column = column; + } + + public String[] getColumn() { + return column; + } + + @Override + public String toString() { + return "AFTER " + CatalogV2Implicits.quoteNameParts(column); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + After after = (After) o; + return Arrays.equals(column, after.column); + } + + @Override + public int hashCode() { + return Arrays.hashCode(column); + } + } + interface ColumnChange extends TableChange { String[] fieldNames(); } @@ -275,12 +368,19 @@ final class AddColumn implements ColumnChange { private final DataType dataType; private final boolean isNullable; private final String comment; - - private AddColumn(String[] fieldNames, DataType dataType, boolean isNullable, String comment) { + private final ColumnPosition position; + + private AddColumn( + String[] fieldNames, + DataType dataType, + boolean isNullable, + String comment, + ColumnPosition position) { this.fieldNames = fieldNames; this.dataType = dataType; this.isNullable = isNullable; this.comment = comment; + this.position = position; } @Override @@ -296,10 +396,16 @@ public boolean isNullable() { return isNullable; } + @Nullable public String comment() { return comment; } + @Nullable + public ColumnPosition position() { + return position; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -308,12 +414,13 @@ public boolean equals(Object o) { return isNullable == addColumn.isNullable && Arrays.equals(fieldNames, addColumn.fieldNames) && dataType.equals(addColumn.dataType) && - comment.equals(addColumn.comment); + Objects.equals(comment, addColumn.comment) && + Objects.equals(position, addColumn.position); } @Override public int hashCode() { - int result = Objects.hash(dataType, isNullable, comment); + int result = Objects.hash(dataType, isNullable, comment, position); result = 31 * result + Arrays.hashCode(fieldNames); return result; } @@ -453,6 +560,48 @@ public int hashCode() { } } + /** + * A TableChange to update the position of a field. + *
+   * The field names are used to find the field to update.
+   * <p>
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class UpdateColumnPosition implements ColumnChange { + private final String[] fieldNames; + private final ColumnPosition position; + + private UpdateColumnPosition(String[] fieldNames, ColumnPosition position) { + this.fieldNames = fieldNames; + this.position = position; + } + + @Override + public String[] fieldNames() { + return fieldNames; + } + + public ColumnPosition position() { + return position; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UpdateColumnPosition that = (UpdateColumnPosition) o; + return Arrays.equals(fieldNames, that.fieldNames) && + position.equals(that.position); + } + + @Override + public int hashCode() { + int result = Objects.hash(position); + result = 31 * result + Arrays.hashCode(fieldNames); + return result; + } + } + /** * A TableChange to delete a field. *
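
A quick note for reviewers: below is a minimal sketch of driving the new TableChange API from Scala, using the factory names as they stand in this commit (later commits in this series rename them to first()/after()); the column names are illustrative.

    import org.apache.spark.sql.connector.catalog.TableChange
    import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
    import org.apache.spark.sql.types.StringType

    // ALTER TABLE t ADD COLUMN point.z string AFTER x, expressed as a TableChange
    val addChange = TableChange.addColumn(
      Array("point", "z"),                // nested field: z inside the struct column point
      StringType,
      true,                               // isNullable
      null,                               // no comment
      ColumnPosition.After(Array("x")))   // multipart name here; commit 2 narrows it to a single part

    // ALTER TABLE t ALTER COLUMN point.z FIRST, expressed as a TableChange
    val moveChange = TableChange.updateColumnPosition(Array("point", "z"), ColumnPosition.FIRST)
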
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 8183aa36a5b90..af5a04582edec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -35,19 +35,32 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => val changes = cols.map { col => - TableChange.addColumn(col.name.toArray, col.dataType, true, col.comment.orNull) + TableChange.addColumn( + col.name.toArray, + col.dataType, + true, + col.comment.orNull, + col.position.orNull) } createAlterTable(nameParts, catalog, tbl, changes) case AlterTableAlterColumnStatement( - nameParts @ NonSessionCatalogAndTable(catalog, tbl), colName, dataType, comment) => + nameParts @ NonSessionCatalogAndTable(catalog, tbl), colName, dataType, comment, pos) => + val colNameArray = colName.toArray val typeChange = dataType.map { newDataType => - TableChange.updateColumnType(colName.toArray, newDataType, true) + TableChange.updateColumnType(colNameArray, newDataType, true) } val commentChange = comment.map { newComment => - TableChange.updateColumnComment(colName.toArray, newComment) + TableChange.updateColumnComment(colNameArray, newComment) } - createAlterTable(nameParts, catalog, tbl, typeChange.toSeq ++ commentChange) + val positionChange = pos.map { newPosition => + TableChange.updateColumnPosition(colNameArray, newPosition) + } + createAlterTable( + nameParts, + catalog, + tbl, + typeChange.toSeq ++ commentChange ++ positionChange) case AlterTableRenameColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), col, newName) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 1beadc5e37801..09e0dc6de04c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit import org.apache.spark.sql.connector.catalog.SupportsNamespaces +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -2803,19 +2804,24 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } + override def visitColPosition(ctx: ColPositionContext): ColumnPosition = { + ctx.position.getType match { + case SqlBaseParser.FIRST => ColumnPosition.FIRST + case SqlBaseParser.AFTER => + ColumnPosition.After(typedVisit[Seq[String]](ctx.multipartIdentifier).toArray) + } + } + /** * Parse new column info from ADD COLUMN into a QualifiedColType. 
*/ override def visitQualifiedColTypeWithPosition( ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { - if (ctx.colPosition != null) { - operationNotAllowed("ALTER TABLE table ADD COLUMN ... FIRST | AFTER otherCol", ctx) - } - QualifiedColType( typedVisit[Seq[String]](ctx.name), typedVisit[DataType](ctx.dataType), - Option(ctx.comment).map(string)) + Option(ctx.comment).map(string), + Option(ctx.colPosition).map(typedVisit[ColumnPosition])) } /** @@ -2863,19 +2869,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitAlterTableColumn( ctx: AlterTableColumnContext): LogicalPlan = withOrigin(ctx) { val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER" - if (ctx.colPosition != null) { - operationNotAllowed(s"ALTER TABLE table $verb COLUMN ... FIRST | AFTER otherCol", ctx) - } - - if (ctx.dataType == null && ctx.comment == null) { - operationNotAllowed(s"ALTER TABLE table $verb COLUMN requires a TYPE or a COMMENT", ctx) + if (ctx.dataType == null && ctx.comment == null && ctx.colPosition == null) { + operationNotAllowed( + s"ALTER TABLE table $verb COLUMN requires a TYPE or a COMMENT or a FIRST/AFTER", ctx) } AlterTableAlterColumnStatement( visitMultipartIdentifier(ctx.table), typedVisit[Seq[String]](ctx.column), Option(ctx.dataType).map(typedVisit[DataType]), - Option(ctx.comment).map(string)) + Option(ctx.comment).map(string), + Option(ctx.colPosition).map(typedVisit[ColumnPosition])) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 13356bfd04ffd..a818cc441ec2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.ViewType import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.{DataType, StructType} @@ -141,7 +142,11 @@ case class ReplaceTableAsSelectStatement( /** * Column data as parsed by ALTER TABLE ... ADD COLUMNS. */ -case class QualifiedColType(name: Seq[String], dataType: DataType, comment: Option[String]) +case class QualifiedColType( + name: Seq[String], + dataType: DataType, + comment: Option[String], + position: Option[ColumnPosition]) /** * ALTER TABLE ... ADD COLUMNS command, as parsed from SQL. @@ -157,7 +162,8 @@ case class AlterTableAlterColumnStatement( tableName: Seq[String], column: Seq[String], dataType: Option[DataType], - comment: Option[String]) extends ParsedStatement + comment: Option[String], + position: Option[ColumnPosition]) extends ParsedStatement /** * ALTER TABLE ... RENAME COLUMN command, as parsed from SQL. 
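
For reference, the parser now carries positions through the parsed statements; a sketch of the resulting plans, mirroring the new DDLParserSuite cases in this commit (the table name t is illustrative):

    import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, QualifiedColType}
    import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
    import org.apache.spark.sql.types.IntegerType

    // ALTER TABLE t ADD COLUMN x int AFTER y parses to:
    AlterTableAddColumnsStatement(Seq("t"), Seq(
      QualifiedColType(Seq("x"), IntegerType, None, Some(ColumnPosition.After(Array("y"))))))

    // ALTER TABLE t ALTER COLUMN x FIRST parses to:
    AlterTableAlterColumnStatement(Seq("t"), Seq("x"), None, None, Some(ColumnPosition.FIRST))
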
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 882e968f34b59..ac7bc93bdf480 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -89,13 +89,7 @@ private[sql] object CatalogV2Implicits { } implicit class IdentifierHelper(ident: Identifier) { - def quoted: String = { - if (ident.namespace.nonEmpty) { - ident.namespace.map(quote).mkString(".") + "." + quote(ident.name) - } else { - quote(ident.name) - } - } + def quoted: String = ident.toString def asMultipartIdentifier: Seq[String] = ident.namespace :+ ident.name } @@ -115,7 +109,11 @@ private[sql] object CatalogV2Implicits { s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.") } - def quoted: String = parts.map(quote).mkString(".") + def quoted: String = quoteNameParts(parts.toArray) + } + + def quoteNameParts(nameParts: Array[String]): String = { + nameParts.map(quote).mkString(".") } private def quote(part: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 0dcd595ded191..576f40ab1f7b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -102,6 +102,10 @@ private[sql] object CatalogV2Util { changes.foldLeft(schema) { (schema, change) => change match { case add: AddColumn => + if (add.position != null) { + throw new UnsupportedOperationException("column position is not supported yet.") + } + add.fieldNames match { case Array(name) => val newField = StructField(name, add.dataType, nullable = add.isNullable) @@ -147,6 +151,9 @@ private[sql] object CatalogV2Util { replace(schema, update.fieldNames, field => Some(field.withComment(update.newComment))) + case _: UpdateColumnPosition => + throw new UnsupportedOperationException("column position is not supported yet.") + case delete: DeleteColumn => replace(schema, delete.fieldNames, _ => None) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index b0d9a00d653ce..439e35d05480d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, Loc import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.After +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.FIRST import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, 
TimestampType} @@ -492,7 +494,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) + QualifiedColType(Seq("x"), IntegerType, None, None) ))) } @@ -500,8 +502,8 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMNS x int, y string"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None), - QualifiedColType(Seq("y"), StringType, None) + QualifiedColType(Seq("x"), IntegerType, None, None), + QualifiedColType(Seq("y"), StringType, None, None) ))) } @@ -509,7 +511,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMNS x int"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) + QualifiedColType(Seq("x"), IntegerType, None, None) ))) } @@ -517,7 +519,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMNS (x int)"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) + QualifiedColType(Seq("x"), IntegerType, None, None) ))) } @@ -525,7 +527,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMNS (x int COMMENT 'doc')"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, Some("doc")) + QualifiedColType(Seq("x"), IntegerType, Some("doc"), None) ))) } @@ -533,7 +535,21 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int COMMENT 'doc'"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, Some("doc")) + QualifiedColType(Seq("x"), IntegerType, Some("doc"), None) + ))) + } + + test("alter table: add column with position") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x int FIRST"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None, Some(FIRST)) + ))) + + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x int AFTER y"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None, Some(After(Array("y")))) ))) } @@ -541,25 +557,19 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")) + QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc"), None) ))) } test("alter table: add multiple columns with nested column name") { comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string"), + parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string FIRST"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")), - QualifiedColType(Seq("a", "b"), StringType, None) + QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc"), None), + QualifiedColType(Seq("a", "b"), StringType, None, Some(FIRST)) ))) } - test("alter table: add column at position (not supported)") { - assertUnsupported("ALTER TABLE table_name ADD COLUMNS name bigint COMMENT 'doc' FIRST, a.b int") - assertUnsupported("ALTER TABLE 
table_name ADD COLUMN name bigint COMMENT 'doc' FIRST") - assertUnsupported("ALTER TABLE table_name ADD COLUMN name string AFTER a.b") - } - test("alter table: set location") { comparePlans( parsePlan("ALTER TABLE a.b.c SET LOCATION 'new location'"), @@ -589,6 +599,7 @@ class DDLParserSuite extends AnalysisTest { Seq("table_name"), Seq("a", "b", "c"), Some(LongType), + None, None)) } @@ -599,6 +610,7 @@ class DDLParserSuite extends AnalysisTest { Seq("table_name"), Seq("a", "b", "c"), Some(LongType), + None, None)) } @@ -609,22 +621,31 @@ class DDLParserSuite extends AnalysisTest { Seq("table_name"), Seq("a", "b", "c"), None, - Some("new comment"))) + Some("new comment"), + None)) } - test("alter table: update column type and comment") { + test("alter table: update column position") { comparePlans( - parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c TYPE bigint COMMENT 'new comment'"), + parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c FIRST"), AlterTableAlterColumnStatement( Seq("table_name"), Seq("a", "b", "c"), - Some(LongType), - Some("new comment"))) + None, + None, + Some(FIRST))) } - test("alter table: change column position (not supported)") { - assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name COMMENT 'doc' FIRST") - assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name TYPE INT AFTER other_col") + test("alter table: update column type, comment and position") { + comparePlans( + parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c " + + "TYPE bigint COMMENT 'new comment' AFTER x.y"), + AlterTableAlterColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + Some(LongType), + Some("new comment"), + Some(After(Array("x", "y"))))) } test("alter table: drop column") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 53eb7dae2ca0a..1da4ec92ac2d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -55,13 +55,18 @@ class ResolveSessionCatalog( AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField)) }.getOrElse { val changes = cols.map { col => - TableChange.addColumn(col.name.toArray, col.dataType, true, col.comment.orNull) + TableChange.addColumn( + col.name.toArray, + col.dataType, + true, + col.comment.orNull, + col.position.orNull) } createAlterTable(nameParts, catalog, tbl, changes) } case AlterTableAlterColumnStatement( - nameParts @ SessionCatalogAndTable(catalog, tbl), colName, dataType, comment) => + nameParts @ SessionCatalogAndTable(catalog, tbl), colName, dataType, comment, pos) => loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => if (colName.length > 1) { @@ -72,6 +77,10 @@ class ResolveSessionCatalog( throw new AnalysisException( "ALTER COLUMN with v1 tables must specify new data type.") } + if (pos.isDefined) { + throw new AnalysisException("" + + "ALTER COLUMN ... 
FIRST | ALTER is only supported with v2 tables.") + } val builder = new MetadataBuilder // Add comment to metadata comment.map(c => builder.putString("comment", c)) @@ -87,13 +96,21 @@ class ResolveSessionCatalog( builder.build()) AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName(0), newColumn) }.getOrElse { + val nameParts = colName.toArray val typeChange = dataType.map { newDataType => - TableChange.updateColumnType(colName.toArray, newDataType, true) + TableChange.updateColumnType(nameParts, newDataType, true) } val commentChange = comment.map { newComment => - TableChange.updateColumnComment(colName.toArray, newComment) + TableChange.updateColumnComment(nameParts, newComment) + } + val positionChange = pos.map { newPosition => + TableChange.updateColumnPosition(nameParts, newPosition) } - createAlterTable(nameParts, catalog, tbl, typeChange.toSeq ++ commentChange) + createAlterTable( + nameParts, + catalog, + tbl, + typeChange.toSeq ++ commentChange ++ positionChange) } case AlterTableRenameColumnStatement( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 7392850f276cc..8882274b040fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -101,6 +101,21 @@ trait AlterTableTests extends SharedSparkSession { } } + test("AlterTable: add column with position") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING $v2Format") + val e1 = intercept[UnsupportedOperationException] { + sql(s"ALTER TABLE $t ADD COLUMN data string FIRST") + } + assert(e1.getMessage.contains("column position is not supported")) + val e2 = intercept[UnsupportedOperationException] { + sql(s"ALTER TABLE $t ADD COLUMN data string AFTER id") + } + assert(e2.getMessage.contains("column position is not supported")) + } + } + test("AlterTable: add multiple columns") { val t = s"${catalogAndNamespace}table_name" withTable(t) { @@ -471,6 +486,21 @@ trait AlterTableTests extends SharedSparkSession { } } + test("AlterTable: update column position") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql(s"CREATE TABLE $t (a int, b int) USING $v2Format") + val e1 = intercept[UnsupportedOperationException] { + sql(s"ALTER TABLE $t ALTER COLUMN b FIRST") + } + assert(e1.getMessage.contains("column position is not supported")) + val e2 = intercept[UnsupportedOperationException] { + sql(s"ALTER TABLE $t ALTER COLUMN a AFTER b") + } + assert(e2.getMessage.contains("column position is not supported")) + } + } + test("AlterTable: update column type and comment") { val t = s"${catalogAndNamespace}table_name" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 6675636c0e62f..60e6c018a3b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1191,7 +1191,7 @@ class DataSourceV2SQLSuite } test("tableCreation: duplicate column names in the table definition") { - val errorMsg = "Found duplicate column(s) in the table definition of `t`" + val errorMsg = "Found duplicate column(s) in the table definition of t" Seq((true, ("a", "a")), (false, ("aA", 
"Aa"))).foreach { case (caseSensitive, (c0, c1)) => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { assertAnalysisError( @@ -1215,7 +1215,7 @@ class DataSourceV2SQLSuite } test("tableCreation: duplicate nested column names in the table definition") { - val errorMsg = "Found duplicate column(s) in the table definition of `t`" + val errorMsg = "Found duplicate column(s) in the table definition of t" Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { assertAnalysisError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 10873678e05f2..2bb121b27e7d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -181,6 +181,16 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { assert(e.contains("Hive built-in ORC data source must be used with Hive support enabled")) } } + + test("ALTER TABLE ALTER COLUMN with position is not supported") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t ALTER COLUMN i TYPE INT FIRST") + } + assert(e.message.contains("ALTER COLUMN ... FIRST | ALTER is only supported with v2 tables")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { From 43398af2bf8134c7540329df707516e70705989e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 11 Dec 2019 15:02:33 +0800 Subject: [PATCH 2/6] address comments --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/connector/catalog/TableChange.java | 18 +++++++++++------- .../catalyst/analysis/ResolveCatalogs.scala | 2 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 3 +-- .../sql/catalyst/parser/DDLParserSuite.scala | 9 ++++----- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9b8c4a52fab68..c102dd251e34a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -855,7 +855,7 @@ intervalUnit ; colPosition - : position=FIRST | position=AFTER multipartIdentifier + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier ; dataType diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 5012deb15fef1..2f0f6a89cc734 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -301,13 +301,15 @@ public int hashCode() { interface ColumnPosition { First FIRST = new First(); - static ColumnPosition After(String[] column) { + static ColumnPosition createAfter(String column) { return new After(column); } } /** * Column position FIRST means the specified column should be the first column. + * Note that, the specified column may be a nested field, and then FIRST means this field should + * be the first one within the struct. 
*/ final class First implements ColumnPosition { private First() {} @@ -320,22 +322,24 @@ public String toString() { /** * Column position AFTER means the specified column should be put after the given `column`. + * Note that, the specified column may be a nested field, and then the given `column` refers to + * a nested field in the same struct. */ final class After implements ColumnPosition { - private final String[] column; + private final String column; - private After(String[] column) { + private After(String column) { assert column != null; this.column = column; } - public String[] getColumn() { + public String getColumn() { return column; } @Override public String toString() { - return "AFTER " + CatalogV2Implicits.quoteNameParts(column); + return "AFTER " + column; } @Override @@ -343,12 +347,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; After after = (After) o; - return Arrays.equals(column, after.column); + return column.equals(after.column); } @Override public int hashCode() { - return Arrays.hashCode(column); + return Objects.hash(column); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index af5a04582edec..3361173c9962f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -45,7 +45,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) createAlterTable(nameParts, catalog, tbl, changes) case AlterTableAlterColumnStatement( - nameParts @ NonSessionCatalogAndTable(catalog, tbl), colName, dataType, comment, pos) => + nameParts @ NonSessionCatalogAndTable(catalog, tbl), colName, dataType, comment, pos) => val colNameArray = colName.toArray val typeChange = dataType.map { newDataType => TableChange.updateColumnType(colNameArray, newDataType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 09e0dc6de04c7..efea5378bae61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2807,8 +2807,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitColPosition(ctx: ColPositionContext): ColumnPosition = { ctx.position.getType match { case SqlBaseParser.FIRST => ColumnPosition.FIRST - case SqlBaseParser.AFTER => - ColumnPosition.After(typedVisit[Seq[String]](ctx.multipartIdentifier).toArray) + case SqlBaseParser.AFTER => ColumnPosition.createAfter(ctx.afterCol.getText) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 439e35d05480d..9b5793cbc14ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, Loc import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} 
import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.After -import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.FIRST +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{createAfter, FIRST} import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} @@ -549,7 +548,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int AFTER y"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None, Some(After(Array("y")))) + QualifiedColType(Seq("x"), IntegerType, None, Some(createAfter("y"))) ))) } @@ -639,13 +638,13 @@ class DDLParserSuite extends AnalysisTest { test("alter table: update column type, comment and position") { comparePlans( parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c " + - "TYPE bigint COMMENT 'new comment' AFTER x.y"), + "TYPE bigint COMMENT 'new comment' AFTER d"), AlterTableAlterColumnStatement( Seq("table_name"), Seq("a", "b", "c"), Some(LongType), Some("new comment"), - Some(After(Array("x", "y"))))) + Some(createAfter("d")))) } test("alter table: drop column") { From af271ee86424c60d4318a88d30450b4b1205983d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Dec 2019 02:19:09 +0800 Subject: [PATCH 3/6] address comment --- .../org/apache/spark/sql/connector/catalog/TableChange.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 2f0f6a89cc734..674f6234709fd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -323,7 +323,7 @@ public String toString() { /** * Column position AFTER means the specified column should be put after the given `column`. * Note that, the specified column may be a nested field, and then the given `column` refers to - * a nested field in the same struct. + * a field in the same struct. 
*/ final class After implements ColumnPosition { private final String column; From 8d865ab337dc42fb1fa44dfe8ed7683b49136a5e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Dec 2019 15:21:22 +0800 Subject: [PATCH 4/6] address comments --- .../sql/connector/catalog/IdentifierImpl.java | 6 ++++-- .../spark/sql/connector/catalog/TableChange.java | 11 ++++++++--- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../connector/catalog/CatalogV2Implicits.scala | 16 +++++++++------- .../sql/catalyst/parser/DDLParserSuite.scala | 12 ++++++------ 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java index cfc71bbb78727..a56007b2a5ab8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java @@ -19,6 +19,7 @@ import java.util.Arrays; import java.util.Objects; +import java.util.stream.Collectors; import java.util.stream.Stream; import com.google.common.base.Preconditions; @@ -53,8 +54,9 @@ public String name() { @Override public String toString() { - return CatalogV2Implicits.quoteNameParts(Stream.concat( - Stream.of(namespace), Stream.of(name)).toArray(String[]::new)); + return Stream.concat(Stream.of(namespace), Stream.of(name)) + .map(CatalogV2Implicits::quote) + .collect(Collectors.joining(".")); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 674f6234709fd..d2a1c455983b9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -299,9 +299,12 @@ public int hashCode() { } interface ColumnPosition { - First FIRST = new First(); - static ColumnPosition createAfter(String column) { + static ColumnPosition first() { + return First.singleton; + } + + static ColumnPosition after(String column) { return new After(column); } } @@ -312,6 +315,8 @@ static ColumnPosition createAfter(String column) { * be the first one within the struct. 
*/ final class First implements ColumnPosition { + private static First singleton = new First(); + private First() {} @Override @@ -333,7 +338,7 @@ private After(String column) { this.column = column; } - public String getColumn() { + public String column() { return column; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index efea5378bae61..8e51f65144042 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2806,8 +2806,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitColPosition(ctx: ColPositionContext): ColumnPosition = { ctx.position.getType match { - case SqlBaseParser.FIRST => ColumnPosition.FIRST - case SqlBaseParser.AFTER => ColumnPosition.createAfter(ctx.afterCol.getText) + case SqlBaseParser.FIRST => ColumnPosition.first() + case SqlBaseParser.AFTER => ColumnPosition.after(ctx.afterCol.getText) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index ac7bc93bdf480..86e5894b369aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -89,7 +89,13 @@ private[sql] object CatalogV2Implicits { } implicit class IdentifierHelper(ident: Identifier) { - def quoted: String = ident.toString + def quoted: String = { + if (ident.namespace.nonEmpty) { + ident.namespace.map(quote).mkString(".") + "." 
+ quote(ident.name) + } else { + quote(ident.name) + } + } def asMultipartIdentifier: Seq[String] = ident.namespace :+ ident.name } @@ -109,14 +115,10 @@ private[sql] object CatalogV2Implicits { s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.") } - def quoted: String = quoteNameParts(parts.toArray) - } - - def quoteNameParts(nameParts: Array[String]): String = { - nameParts.map(quote).mkString(".") + def quoted: String = parts.map(quote).mkString(".") } - private def quote(part: String): String = { + def quote(part: String): String = { if (part.contains(".") || part.contains("`")) { s"`${part.replace("`", "``")}`" } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 9b5793cbc14ba..2d4a19a0a2ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, Loc import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{createAfter, FIRST} +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} @@ -542,13 +542,13 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int FIRST"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None, Some(FIRST)) + QualifiedColType(Seq("x"), IntegerType, None, Some(first())) ))) comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int AFTER y"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None, Some(createAfter("y"))) + QualifiedColType(Seq("x"), IntegerType, None, Some(after("y"))) ))) } @@ -565,7 +565,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string FIRST"), AlterTableAddColumnsStatement(Seq("table_name"), Seq( QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc"), None), - QualifiedColType(Seq("a", "b"), StringType, None, Some(FIRST)) + QualifiedColType(Seq("a", "b"), StringType, None, Some(first())) ))) } @@ -632,7 +632,7 @@ class DDLParserSuite extends AnalysisTest { Seq("a", "b", "c"), None, None, - Some(FIRST))) + Some(first()))) } test("alter table: update column type, comment and position") { @@ -644,7 +644,7 @@ class DDLParserSuite extends AnalysisTest { Seq("a", "b", "c"), Some(LongType), Some("new comment"), - Some(createAfter("d")))) + Some(after("d")))) } test("alter table: drop column") { From ea58952c606c77c83230be090a7c5457df36d5ae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Dec 2019 01:51:09 +0800 Subject: [PATCH 5/6] improve test --- .../sql/connector/catalog/CatalogV2Util.scala | 59 ++++++++++++------- 
.../analysis/ResolveSessionCatalog.scala | 8 +-- .../sql-tests/results/change-column.sql.out | 20 ++----- .../spark/sql/connector/AlterTableTests.scala | 48 +++++++++------ 4 files changed, 77 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 576f40ab1f7b1..33619e81d6d02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -102,32 +102,18 @@ private[sql] object CatalogV2Util { changes.foldLeft(schema) { (schema, change) => change match { case add: AddColumn => - if (add.position != null) { - throw new UnsupportedOperationException("column position is not supported yet.") - } - add.fieldNames match { case Array(name) => - val newField = StructField(name, add.dataType, nullable = add.isNullable) - Option(add.comment) match { - case Some(comment) => - schema.add(newField.withComment(comment)) - case _ => - schema.add(newField) - } + val field = StructField(name, add.dataType, nullable = add.isNullable) + val newField = Option(add.comment).map(field.withComment).getOrElse(field) + addField(schema, newField, add.position()) case names => replace(schema, names.init, parent => parent.dataType match { case parentType: StructType => val field = StructField(names.last, add.dataType, nullable = add.isNullable) - val newParentType = Option(add.comment) match { - case Some(comment) => - parentType.add(field.withComment(comment)) - case None => - parentType.add(field) - } - - Some(StructField(parent.name, newParentType, parent.nullable, parent.metadata)) + val newField = Option(add.comment).map(field.withComment).getOrElse(field) + Some(parent.copy(dataType = addField(parentType, newField, add.position()))) case _ => throw new IllegalArgumentException(s"Not a struct: ${names.init.last}") @@ -151,8 +137,24 @@ private[sql] object CatalogV2Util { replace(schema, update.fieldNames, field => Some(field.withComment(update.newComment))) - case _: UpdateColumnPosition => - throw new UnsupportedOperationException("column position is not supported yet.") + case update: UpdateColumnPosition => + def updateFieldPos(struct: StructType, name: String): StructType = { + val oldField = struct.fields.find(_.name == name).getOrElse { + throw new IllegalArgumentException("field not found: " + name) + } + val withFieldRemoved = StructType(struct.fields.filter(_ != oldField)) + addField(withFieldRemoved, oldField, update.position()) + } + + update.fieldNames() match { + case Array(name) => + updateFieldPos(schema, name) + case names => + replace(schema, names.init, parent => parent.dataType match { + case parentType: StructType => + Some(parent.copy(dataType = updateFieldPos(parentType, names.last))) + }) + } case delete: DeleteColumn => replace(schema, delete.fieldNames, _ => None) @@ -164,6 +166,21 @@ private[sql] object CatalogV2Util { } } + private def addField( + schema: StructType, + field: StructField, + position: ColumnPosition): StructType = { + if (position == null) { + schema.add(field) + } else if (position.isInstanceOf[First]) { + StructType(field +: schema.fields) + } else { + val afterCol = position.asInstanceOf[After].column() + val (before, after) = schema.fields.span(_.name == afterCol) + StructType(before ++ (field +: after)) + } + } + private def replace( struct: StructType, fieldNames: 
Seq[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 1da4ec92ac2d0..75651bf5e24d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -96,15 +96,15 @@ class ResolveSessionCatalog( builder.build()) AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName(0), newColumn) }.getOrElse { - val nameParts = colName.toArray + val colNameArray = colName.toArray val typeChange = dataType.map { newDataType => - TableChange.updateColumnType(nameParts, newDataType, true) + TableChange.updateColumnType(colNameArray, newDataType, true) } val commentChange = comment.map { newComment => - TableChange.updateColumnComment(nameParts, newComment) + TableChange.updateColumnComment(colNameArray, newComment) } val positionChange = pos.map { newPosition => - TableChange.updateColumnPosition(nameParts, newPosition) + TableChange.updateColumnPosition(colNameArray, newPosition) } createAlterTable( nameParts, diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 21a344c071bc4..82326346b361c 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -27,7 +27,7 @@ struct<> -- !query 2 output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: ALTER TABLE table CHANGE COLUMN requires a TYPE or a COMMENT(line 1, pos 0) +Operation not allowed: ALTER TABLE table CHANGE COLUMN requires a TYPE or a COMMENT or a FIRST/AFTER(line 1, pos 0) == SQL == ALTER TABLE test_change CHANGE a @@ -87,13 +87,8 @@ ALTER TABLE test_change CHANGE a TYPE INT AFTER b -- !query 8 schema struct<> -- !query 8 output -org.apache.spark.sql.catalyst.parser.ParseException - -Operation not allowed: ALTER TABLE table CHANGE COLUMN ... FIRST | AFTER otherCol(line 1, pos 0) - -== SQL == -ALTER TABLE test_change CHANGE a TYPE INT AFTER b -^^^ +org.apache.spark.sql.AnalysisException +ALTER COLUMN ... FIRST | ALTER is only supported with v2 tables.; -- !query 9 @@ -101,13 +96,8 @@ ALTER TABLE test_change CHANGE b TYPE STRING FIRST -- !query 9 schema struct<> -- !query 9 output -org.apache.spark.sql.catalyst.parser.ParseException - -Operation not allowed: ALTER TABLE table CHANGE COLUMN ... FIRST | AFTER otherCol(line 1, pos 0) - -== SQL == -ALTER TABLE test_change CHANGE b TYPE STRING FIRST -^^^ +org.apache.spark.sql.AnalysisException +ALTER COLUMN ... 
FIRST | ALTER is only supported with v2 tables.; -- !query 10 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 8882274b040fa..09582f1dbba52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -104,15 +104,21 @@ trait AlterTableTests extends SharedSparkSession { test("AlterTable: add column with position") { val t = s"${catalogAndNamespace}table_name" withTable(t) { - sql(s"CREATE TABLE $t (id int) USING $v2Format") - val e1 = intercept[UnsupportedOperationException] { - sql(s"ALTER TABLE $t ADD COLUMN data string FIRST") - } - assert(e1.getMessage.contains("column position is not supported")) - val e2 = intercept[UnsupportedOperationException] { - sql(s"ALTER TABLE $t ADD COLUMN data string AFTER id") - } - assert(e2.getMessage.contains("column position is not supported")) + sql(s"CREATE TABLE $t (id struct) USING $v2Format") + + sql(s"ALTER TABLE $t ADD COLUMN a string FIRST") + assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "id")) + + sql(s"ALTER TABLE $t ADD COLUMN b string AFTER a") + assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "b", "id")) + + sql(s"ALTER TABLE $t ADD COLUMN id.y string FIRST") + assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq == + Seq("y", "x")) + + sql(s"ALTER TABLE $t ADD COLUMN id.z string AFTER y") + assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq == + Seq("y", "z", "x")) } } @@ -489,15 +495,21 @@ trait AlterTableTests extends SharedSparkSession { test("AlterTable: update column position") { val t = s"${catalogAndNamespace}table_name" withTable(t) { - sql(s"CREATE TABLE $t (a int, b int) USING $v2Format") - val e1 = intercept[UnsupportedOperationException] { - sql(s"ALTER TABLE $t ALTER COLUMN b FIRST") - } - assert(e1.getMessage.contains("column position is not supported")) - val e2 = intercept[UnsupportedOperationException] { - sql(s"ALTER TABLE $t ALTER COLUMN a AFTER b") - } - assert(e2.getMessage.contains("column position is not supported")) + sql(s"CREATE TABLE $t (a int, b struct) USING $v2Format") + + sql(s"ALTER TABLE $t ALTER COLUMN b FIRST") + assert(getTableMetadata(t).schema().names.toSeq == Seq("b", "a")) + + sql(s"ALTER TABLE $t ALTER COLUMN b AFTER a") + assert(getTableMetadata(t).schema().names.toSeq == Seq("a", "b")) + + sql(s"ALTER TABLE $t ALTER COLUMN b.y FIRST") + assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq == + Seq("y", "x")) + + sql(s"ALTER TABLE $t ALTER COLUMN b.y AFTER x") + assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq == + Seq("x", "y")) } } From c01f565d048f9f84aa08616113941e5f072158c4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Dec 2019 15:47:28 +0800 Subject: [PATCH 6/6] address comments --- .../sql/connector/catalog/TableChange.java | 4 +- .../sql/connector/catalog/CatalogV2Util.scala | 10 +- .../spark/sql/connector/AlterTableTests.scala | 108 +++++++++++++----- 3 files changed, 92 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index d2a1c455983b9..783439935c8d2 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
@@ -301,7 +301,7 @@ public int hashCode() {
   interface ColumnPosition {
 
     static ColumnPosition first() {
-      return First.singleton;
+      return First.SINGLETON;
     }
 
     static ColumnPosition after(String column) {
@@ -315,7 +315,7 @@ static ColumnPosition after(String column) {
    * be the first one within the struct.
    */
   final class First implements ColumnPosition {
-    private static First singleton = new First();
+    private static final First SINGLETON = new First();
 
     private First() {}
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 33619e81d6d02..2f4914dd7db30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -140,7 +140,7 @@ private[sql] object CatalogV2Util {
       case update: UpdateColumnPosition =>
         def updateFieldPos(struct: StructType, name: String): StructType = {
           val oldField = struct.fields.find(_.name == name).getOrElse {
-            throw new IllegalArgumentException("field not found: " + name)
+            throw new IllegalArgumentException("Field not found: " + name)
           }
           val withFieldRemoved = StructType(struct.fields.filter(_ != oldField))
           addField(withFieldRemoved, oldField, update.position())
@@ -153,6 +153,8 @@ private[sql] object CatalogV2Util {
           replace(schema, names.init, parent => parent.dataType match {
             case parentType: StructType =>
               Some(parent.copy(dataType = updateFieldPos(parentType, names.last)))
+            case _ =>
+              throw new IllegalArgumentException(s"Not a struct: ${names.init.last}")
           })
         }
@@ -176,7 +178,11 @@ private[sql] object CatalogV2Util {
       StructType(field +: schema.fields)
     } else {
       val afterCol = position.asInstanceOf[After].column()
-      val (before, after) = schema.fields.span(_.name == afterCol)
+      val fieldIndex = schema.fields.indexWhere(_.name == afterCol)
+      if (fieldIndex == -1) {
+        throw new IllegalArgumentException("AFTER column not found: " + afterCol)
+      }
+      val (before, after) = schema.fields.splitAt(fieldIndex + 1)
       StructType(before ++ (field +: after))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index 09582f1dbba52..2ba3c99dfbefd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -104,21 +104,43 @@ trait AlterTableTests extends SharedSparkSession {
   test("AlterTable: add column with position") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id struct<x: int>) USING $v2Format")
+      sql(s"CREATE TABLE $t (point struct<x: int>) USING $v2Format")
 
       sql(s"ALTER TABLE $t ADD COLUMN a string FIRST")
-      assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "id"))
-
-      sql(s"ALTER TABLE $t ADD COLUMN b string AFTER a")
-      assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "b", "id"))
-
-      sql(s"ALTER TABLE $t ADD COLUMN id.y string FIRST")
-      assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq ==
-        Seq("y", "x"))
-
-      sql(s"ALTER TABLE $t ADD COLUMN id.z string AFTER y")
-      
assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq ==
-        Seq("y", "z", "x"))
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", StringType)
+        .add("point", new StructType().add("x", IntegerType)))
+
+      sql(s"ALTER TABLE $t ADD COLUMN b string AFTER point")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", StringType)
+        .add("point", new StructType().add("x", IntegerType))
+        .add("b", StringType))
+
+      val e1 = intercept[SparkException](
+        sql(s"ALTER TABLE $t ADD COLUMN c string AFTER non_exist"))
+      assert(e1.getMessage().contains("AFTER column not found"))
+
+      sql(s"ALTER TABLE $t ADD COLUMN point.y int FIRST")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", StringType)
+        .add("point", new StructType()
+          .add("y", IntegerType)
+          .add("x", IntegerType))
+        .add("b", StringType))
+
+      sql(s"ALTER TABLE $t ADD COLUMN point.z int AFTER x")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", StringType)
+        .add("point", new StructType()
+          .add("y", IntegerType)
+          .add("x", IntegerType)
+          .add("z", IntegerType))
+        .add("b", StringType))
+
+      val e2 = intercept[SparkException](
+        sql(s"ALTER TABLE $t ADD COLUMN point.x2 int AFTER non_exist"))
+      assert(e2.getMessage().contains("AFTER column not found"))
     }
   }
@@ -495,21 +517,55 @@ trait AlterTableTests extends SharedSparkSession {
   test("AlterTable: update column position") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
-      sql(s"CREATE TABLE $t (a int, b struct<x: int, y: int>) USING $v2Format")
+      sql(s"CREATE TABLE $t (a int, b int, point struct<x: int, y: int, z: int>) USING $v2Format")
 
       sql(s"ALTER TABLE $t ALTER COLUMN b FIRST")
-      assert(getTableMetadata(t).schema().names.toSeq == Seq("b", "a"))
-
-      sql(s"ALTER TABLE $t ALTER COLUMN b AFTER a")
-      assert(getTableMetadata(t).schema().names.toSeq == Seq("a", "b"))
-
-      sql(s"ALTER TABLE $t ALTER COLUMN b.y FIRST")
-      assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq ==
-        Seq("y", "x"))
-
-      sql(s"ALTER TABLE $t ALTER COLUMN b.y AFTER x")
-      assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq ==
-        Seq("x", "y"))
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("b", IntegerType)
+        .add("a", IntegerType)
+        .add("point", new StructType()
+          .add("x", IntegerType)
+          .add("y", IntegerType)
+          .add("z", IntegerType)))
+
+      sql(s"ALTER TABLE $t ALTER COLUMN b AFTER point")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", IntegerType)
+        .add("point", new StructType()
+          .add("x", IntegerType)
+          .add("y", IntegerType)
+          .add("z", IntegerType))
+        .add("b", IntegerType))
+
+      val e1 = intercept[SparkException](
+        sql(s"ALTER TABLE $t ALTER COLUMN b AFTER non_exist"))
+      assert(e1.getMessage.contains("AFTER column not found"))
+
+      sql(s"ALTER TABLE $t ALTER COLUMN point.y FIRST")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", IntegerType)
+        .add("point", new StructType()
+          .add("y", IntegerType)
+          .add("x", IntegerType)
+          .add("z", IntegerType))
+        .add("b", IntegerType))
+
+      sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER z")
+      assert(getTableMetadata(t).schema == new StructType()
+        .add("a", IntegerType)
+        .add("point", new StructType()
+          .add("x", IntegerType)
+          .add("z", IntegerType)
+          .add("y", IntegerType))
+        .add("b", IntegerType))
+
+      val e2 = intercept[SparkException](
+        sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER non_exist"))
+      assert(e2.getMessage.contains("AFTER column not found"))
+
+      // 
`AlterTable.resolved` checks column existence. + intercept[AnalysisException]( + sql(s"ALTER TABLE $t ALTER COLUMN a.y AFTER x")) } }
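
Reviewer note, not part of the patch: below is a minimal sketch of how a caller is expected to drive the position-aware API that this series adds. ColumnPosition.first()/after() and updateColumnPosition are taken from the diffs above; the five-argument addColumn overload (with a trailing ColumnPosition) is assumed from the new AddColumn constructor shape, and all table/column names here are made up for illustration.

import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.types.StringType

// Mirrors `ALTER TABLE t ADD COLUMN a string FIRST`: a nullable top-level
// column with no comment, placed before all existing columns.
val addFirst = TableChange.addColumn(
  Array("a"), StringType, true, null, ColumnPosition.first())

// Mirrors `ALTER TABLE t ADD COLUMN point.z string AFTER x`: the AFTER
// reference is resolved inside the struct that receives the new field,
// not against the top-level schema.
val addNested = TableChange.addColumn(
  Array("point", "z"), StringType, true, null, ColumnPosition.after("x"))

// Mirrors `ALTER TABLE t ALTER COLUMN b FIRST`.
val moveFirst = TableChange.updateColumnPosition(Array("b"), ColumnPosition.first())

// A v2 catalog receives these via TableCatalog.alterTable(ident, changes...);
// CatalogV2Util.applySchemaChanges (patched above) applies them to a schema
// and raises IllegalArgumentException for a missing AFTER column.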