diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f1012edd2de2d..64b71af925527 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -788,6 +788,20 @@ }, "sqlState" : "XX000" }, + "CONSTRAINT_ALREADY_EXISTS" : { + "message" : [ + "Constraint '' already exists. Please delete the old constraint first.", + "Old constraint:", + "" + ], + "sqlState" : "42710" + }, + "CONSTRAINT_DOES_NOT_EXIST" : { + "message" : [ + "Cannot drop nonexistent constraint from table ." + ], + "sqlState" : "42704" + }, "CONVERSION_INVALID_INPUT" : { "message" : [ "The value () cannot be converted to because it is malformed. Correct the value as per the syntax, or change its format. Use to tolerate malformed input and return NULL instead." @@ -2303,6 +2317,29 @@ ], "sqlState" : "22P03" }, + "INVALID_CHECK_CONSTRAINT": { + "message" : [ + "The check constraint expression is invalid." + ], + "subClass" : { + "INVALID_V2_PREDICATE" : { + "message" : [ + "It cannot be converted to a data source V2 predicate." + ] + }, + "NONDETERMINISTIC" : { + "message" : [ + "It contains nondeterministic expression." + ] + }, + "MISSING_NAME": { + "message": [ + "The check constraint must have a name." + ] + } + }, + "sqlState": "42621" + }, "INVALID_COLUMN_NAME_AS_PATH" : { "message" : [ "The datasource cannot save the column because its name contains some characters that are not allowed in file paths. Please, use an alias to rename it." @@ -2328,6 +2365,12 @@ }, "sqlState" : "22022" }, + "INVALID_CONSTRAINT_CHARACTERISTICS": { + "message": [ + "Constraint characteristics [] are duplicated or conflict with each other." + ], + "sqlState": "42613" + }, "INVALID_CORRUPT_RECORD_TYPE" : { "message" : [ "The column for corrupt records must have the nullable STRING type, but got ." @@ -3884,6 +3927,12 @@ ], "sqlState" : "42P20" }, + "MULTIPLE_PRIMARY_KEYS" : { + "message" : [ + "Multiple primary keys are defined. Please ensure that only one primary key is defined for the table." + ], + "sqlState" : "42K0E" + }, "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS" : { "message" : [ " and cannot coexist in the same SQL pipe operator using '|>'. Please separate the multiple result clauses into separate pipe operators and then retry the query again." @@ -5452,6 +5501,12 @@ }, "sqlState" : "0A000" }, + "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC": { + "message": [ + "Constraint characteristic '' is not supported for constraint type ''." + ], + "sqlState": "0A000" + }, "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : { "message" : [ "Unsupported data source type for direct query on files: " @@ -5611,6 +5666,11 @@ "Attach a comment to the namespace ." ] }, + "CONSTRAINT_TYPE" : { + "message" : [ + "Constraint ." + ] + }, "CONTINUE_EXCEPTION_HANDLER" : { "message" : [ "CONTINUE exception handler is not supported. Use EXIT handler." diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e38efc27b78f9..ca8326918feca 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -371,10 +371,17 @@ abstract class SparkFunSuite "Invalid objectType of a query context Actual:" + actual.toString) assert(actual.objectName() === expected.objectName, "Invalid objectName of a query context. 
Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex, - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex, - "Invalid stopIndex of a query context. Actual:" + actual.toString) + // If startIndex and stopIndex are -1, it means we simply want to check the + // fragment of the query context. This should be the case when the fragment is + // distinguished within the query text. + if (expected.startIndex != -1) { + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + } + if (expected.stopIndex != -1) { + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + } assert(actual.fragment() === expected.fragment, "Invalid fragment of a query context. Actual:" + actual.toString) } else if (actual.contextType() == QueryContextType.DataFrame) { @@ -478,6 +485,12 @@ abstract class SparkFunSuite ExpectedContext("", "", start, stop, fragment) } + // Check the fragment only. This is only used when the fragment is distinguished within + // the query text + def apply(fragment: String): ExpectedContext = { + ExpectedContext("", "", -1, -1, fragment) + } + def apply( objectType: String, objectName: String, diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index b868eea41b692..0975b4dc61f05 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -222,6 +222,7 @@ DROP: 'DROP'; ELSE: 'ELSE'; ELSEIF: 'ELSEIF'; END: 'END'; +ENFORCED: 'ENFORCED'; ESCAPE: 'ESCAPE'; ESCAPED: 'ESCAPED'; EVOLUTION: 'EVOLUTION'; @@ -290,6 +291,7 @@ ITEMS: 'ITEMS'; ITERATE: 'ITERATE'; JOIN: 'JOIN'; JSON: 'JSON'; +KEY: 'KEY'; KEYS: 'KEYS'; LANGUAGE: 'LANGUAGE'; LAST: 'LAST'; @@ -337,6 +339,8 @@ NOT: 'NOT'; NULL: 'NULL'; NULLS: 'NULLS'; NUMERIC: 'NUMERIC'; +NORELY: 'NORELY'; +NOVALIDATE: 'NOVALIDATE'; OF: 'OF'; OFFSET: 'OFFSET'; ON: 'ON'; @@ -376,6 +380,7 @@ RECURSIVE: 'RECURSIVE'; REDUCE: 'REDUCE'; REFERENCES: 'REFERENCES'; REFRESH: 'REFRESH'; +RELY: 'RELY'; RENAME: 'RENAME'; REPAIR: 'REPAIR'; REPEAT: 'REPEAT'; @@ -475,6 +480,7 @@ UPDATE: 'UPDATE'; USE: 'USE'; USER: 'USER'; USING: 'USING'; +VALIDATE: 'VALIDATE'; VALUE: 'VALUE'; VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 59a0b1ce7a3c5..f930552444a9a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -198,7 +198,7 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=stringLit)? #showNamespaces - | createTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? + | createTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #createTable | CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier @@ -208,7 +208,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=propertyList))* #createTableLike - | replaceTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? 
+ | replaceTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS @@ -261,6 +261,10 @@ statement | ALTER TABLE identifierReference (clusterBySpec | CLUSTER BY NONE) #alterClusterBy | ALTER TABLE identifierReference collationSpec #alterTableCollation + | ALTER TABLE identifierReference ADD tableConstraintDefinition #addTableConstraint + | ALTER TABLE identifierReference + DROP CONSTRAINT (IF EXISTS)? name=identifier + (RESTRICT | CASCADE)? #dropTableConstraint | DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable | DROP VIEW (IF EXISTS)? identifierReference #dropView | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? @@ -1334,6 +1338,15 @@ colType : colName=errorCapturingIdentifier dataType (errorCapturingNot NULL)? commentSpec? ; +tableElementList + : tableElement (COMMA tableElement)* + ; + +tableElement + : tableConstraintDefinition + | colDefinition + ; + colDefinitionList : colDefinition (COMMA colDefinition)* ; @@ -1347,6 +1360,7 @@ colDefinitionOption | defaultExpression | generationExpression | commentSpec + | columnConstraintDefinition ; generationExpression @@ -1516,6 +1530,62 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; +columnConstraintDefinition + : (CONSTRAINT name=errorCapturingIdentifier)? columnConstraint constraintCharacteristic* + ; + +columnConstraint + : checkConstraint + | uniqueSpec + | referenceSpec + ; + +tableConstraintDefinition + : (CONSTRAINT name=errorCapturingIdentifier)? tableConstraint constraintCharacteristic* + ; + +tableConstraint + : checkConstraint + | uniqueConstraint + | foreignKeyConstraint + ; + +checkConstraint + : CHECK LEFT_PAREN (expr=booleanExpression) RIGHT_PAREN + ; + +uniqueSpec + : UNIQUE + | PRIMARY KEY + ; + +uniqueConstraint + : uniqueSpec identifierList + ; + +referenceSpec + : REFERENCES multipartIdentifier (parentColumns=identifierList)? 
+ ; + +foreignKeyConstraint + : FOREIGN KEY identifierList referenceSpec + ; + +constraintCharacteristic + : enforcedCharacteristic + | relyCharacteristic + ; + +enforcedCharacteristic + : ENFORCED + | NOT ENFORCED + ; + +relyCharacteristic + : RELY + | NORELY + ; + alterColumnSpecList : alterColumnSpec (COMMA alterColumnSpec)* ; @@ -1673,6 +1743,7 @@ ansiNonReserved | DOUBLE | DROP | ELSEIF + | ENFORCED | ESCAPED | EVOLUTION | EXCHANGE @@ -1761,6 +1832,8 @@ ansiNonReserved | NANOSECONDS | NO | NONE + | NORELY + | NOVALIDATE | NULLS | NUMERIC | OF @@ -1792,6 +1865,7 @@ ansiNonReserved | RECOVER | REDUCE | REFRESH + | RELY | RENAME | REPAIR | REPEAT @@ -1875,6 +1949,7 @@ ansiNonReserved | UNTIL | UPDATE | USE + | VALIDATE | VALUE | VALUES | VARCHAR diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 0bd9f38014984..a73ec2041c01e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -791,4 +791,24 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def clusterByWithBucketing(ctx: ParserRuleContext): Throwable = { new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", ctx) } + + def invalidConstraintCharacteristics( + ctx: ParserRuleContext, + characteristics: String): Throwable = { + new ParseException( + errorClass = "INVALID_CONSTRAINT_CHARACTERISTICS", + messageParameters = Map("characteristics" -> characteristics), + ctx) + } + + def constraintNotSupportedError(ctx: ParserRuleContext, constraint: String): Throwable = { + new ParseException( + errorClass = "UNSUPPORTED_FEATURE.CONSTRAINT_TYPE", + messageParameters = Map("constraint" -> constraint), + ctx) + } + + def multiplePrimaryKeysError(ctx: ParserRuleContext): Throwable = { + new ParseException(errorClass = "MULTIPLE_PRIMARY_KEYS", ctx) + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java index d5eb03dcf94d4..166554b0b4ca0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.catalog; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; @@ -83,4 +84,9 @@ default Map properties() { * Returns the set of capabilities for this table. */ Set capabilities(); + + /** + * Returns the constraints for this table. 
+ */ + default Constraint[] constraints() { return new Constraint[0]; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index 77dbaa7687b41..f98a72b0f9b6c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.catalog; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; @@ -230,6 +231,30 @@ default Table createTable( return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties); } + /** + * Create a table in the catalog. + * + * @param ident a table identifier + * @param columns the columns of the new table. + * @param partitions transforms to use for partitioning data in the table + * @param properties a string map of table properties + * @param constraints constraints for the new table + * @return metadata for the new table. This can be null if getting the metadata for the new table + * is expensive. Spark will call {@link #loadTable(Identifier)} if needed (e.g. CTAS). + * + * @throws TableAlreadyExistsException If a table or view already exists for the identifier + * @throws UnsupportedOperationException If a requested partition transform is not supported + * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) + */ + default Table createTable( + Identifier ident, + Column[] columns, + Transform[] partitions, + Map properties, + Constraint[] constraints) throws TableAlreadyExistsException, NoSuchNamespaceException { + return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties); + } + /** * If true, mark all the fields of the query schema as nullable when executing * CREATE/REPLACE TABLE ... AS SELECT ... and creating the table. 
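(Annotation, not part of the patch.) The TableCatalog change above adds a createTable overload that accepts constraints alongside columns, partitions, and properties. The sketch below is a hypothetical illustration of how a connector might build Constraint objects with the builder API introduced later in this diff and pass them to that overload; the identifier, column names, constraint names, and the createWithConstraints helper are all made up for the example.

import org.apache.spark.sql.connector.catalog.{Column, Identifier, TableCatalog}
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference, Transform}
import org.apache.spark.sql.types.{IntegerType, StringType}

def createWithConstraints(catalog: TableCatalog, ident: Identifier): Unit = {
  val columns = Array(
    Column.create("id", IntegerType),
    Column.create("name", StringType))

  // PRIMARY KEY (id): informational, so explicitly NOT ENFORCED (the builder default is
  // enforced = true) and left UNVALIDATED (the builder default).
  val pk = Constraint
    .primaryKey("t_pk", Array[NamedReference](FieldReference.column("id")))
    .enforced(false)
    .build()

  // CHECK (id > 0): carries the SQL text; a connector may also attach a converted V2 Predicate.
  val check = Constraint
    .check("t_chk")
    .sql("id > 0")
    .build()

  catalog.createTable(
    ident,
    columns,
    Array.empty[Transform],
    java.util.Collections.emptyMap[String, String](),
    Array[Constraint](pk, check))
}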
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index d7a51c519e09b..43b8efaf55fc8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.types.DataType; @@ -260,6 +261,24 @@ static TableChange clusterBy(NamedReference[] clusteringColumns) { return new ClusterBy(clusteringColumns); } + /** + * Create a TableChange for adding a new Table Constraint + */ + static TableChange addConstraint(Constraint constraint, Boolean validate) { + return new AddConstraint(constraint, validate); + } + + /** + * Create a TableChange for dropping a Table Constraint + */ + static TableChange dropConstraint(String name, Boolean ifExists, Boolean cascade) { + DropConstraintMode mode = DropConstraintMode.RESTRICT; + if (cascade) { + mode = DropConstraintMode.CASCADE; + } + return new DropConstraint(name, ifExists, mode); + } + /** * A TableChange to set a table property. *
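(Annotation, not part of the patch.) The hunk below defines AddConstraint and DropConstraint table changes. As a rough, hypothetical sketch of how the new factory methods might be driven from an ALTER TABLE code path: the constraint name and helper functions below are illustrative, and alterTable is the existing TableCatalog method.

import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.constraints.Constraint

def addCheck(catalog: TableCatalog, ident: Identifier): Unit = {
  val check = Constraint.check("positive_id").sql("id > 0").build()
  // validate = true asks the catalog to verify existing rows against the new constraint.
  catalog.alterTable(ident, TableChange.addConstraint(check, true))
}

def dropCheck(catalog: TableCatalog, ident: Identifier): Unit = {
  // ifExists = true tolerates a missing constraint; cascade = false maps to RESTRICT mode.
  catalog.alterTable(ident, TableChange.dropConstraint("positive_id", true, false))
}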
@@ -787,4 +806,75 @@ public int hashCode() { return Arrays.hashCode(clusteringColumns); } } + + /** A TableChange to alter table and add a constraint. */ + final class AddConstraint implements TableChange { + private final Constraint constraint; + private final boolean validate; + + private AddConstraint(Constraint constraint, boolean validate) { + this.constraint = constraint; + this.validate = validate; + } + + public Constraint getConstraint() { + return constraint; + } + + public boolean isValidate() { + return validate; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AddConstraint that = (AddConstraint) o; + return constraint.equals(that.constraint) && validate == that.validate; + } + + @Override + public int hashCode() { + return Objects.hash(constraint, validate); + } + } + + enum DropConstraintMode { RESTRICT, CASCADE } + + /** A TableChange to alter table and drop a constraint. */ + final class DropConstraint implements TableChange { + private final String name; + private final boolean ifExists; + private final DropConstraintMode mode; + + private DropConstraint(String name, boolean ifExists, DropConstraintMode mode) { + this.name = name; + this.ifExists = ifExists; + this.mode = mode; + } + + public String getName() { + return name; + } + + public boolean isIfExists() { + return ifExists; + } + + public DropConstraintMode getMode() { + return mode; + } + + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DropConstraint that = (DropConstraint) o; + return that.name.equals(name) && that.ifExists == ifExists && mode == that.mode; + } + + @Override + public int hashCode() { + return Objects.hash(name, ifExists, mode); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java new file mode 100644 index 0000000000000..3d5bd5afe8aa7 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.StringJoiner; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +abstract class BaseConstraint implements Constraint { + + private final String name; + private final boolean enforced; + private final ValidationStatus validationStatus; + private final boolean rely; + + protected BaseConstraint( + String name, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + this.name = name; + this.enforced = enforced; + this.validationStatus = validationStatus; + this.rely = rely; + } + + protected abstract String definition(); + + @Override + public String name() { + return name; + } + + @Override + public boolean enforced() { + return enforced; + } + + @Override + public ValidationStatus validationStatus() { + return validationStatus; + } + + @Override + public boolean rely() { + return rely; + } + + @Override + public String toDDL() { + return String.format( + "CONSTRAINT %s %s %s %s %s", + name, + definition(), + enforced ? "ENFORCED" : "NOT ENFORCED", + validationStatus, + rely ? "RELY" : "NORELY"); + } + + @Override + public String toString() { + return toDDL(); + } + + protected String toSQL(NamedReference[] columns) { + StringJoiner joiner = new StringJoiner(", "); + + for (NamedReference column : columns) { + joiner.add(column.toString()); + } + + return joiner.toString(); + } + + abstract static class Builder { + private final String name; + private boolean enforced = true; + private ValidationStatus validationStatus = ValidationStatus.UNVALIDATED; + private boolean rely = false; + + Builder(String name) { + this.name = name; + } + + protected abstract B self(); + + public abstract C build(); + + public String name() { + return name; + } + + public B enforced(boolean enforced) { + this.enforced = enforced; + return self(); + } + + public boolean enforced() { + return enforced; + } + + public B validationStatus(ValidationStatus validationStatus) { + this.validationStatus = validationStatus; + return self(); + } + + public ValidationStatus validationStatus() { + return validationStatus; + } + + public B rely(boolean rely) { + this.rely = rely; + return self(); + } + + public boolean rely() { + return rely; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java new file mode 100644 index 0000000000000..fbec7651d2515 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.Map; +import java.util.Objects; + +import org.apache.spark.SparkIllegalArgumentException; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +/** + * A CHECK constraint. + *
+ * A CHECK constraint defines a condition each row in a table must satisfy. Connectors can define + * such constraints either in SQL (Spark SQL dialect) or using a {@link Predicate predicate} if the + * condition can be expressed using a supported expression. A CHECK constraint can reference one or + * more columns. Such constraint is considered violated if its condition evaluates to {@code FALSE}, + * but not {@code NULL}. The search condition must be deterministic and cannot contain subqueries + * and certain functions like aggregates or UDFs. + * + * @since 4.1.0 + */ +public class Check extends BaseConstraint { + + private final String sql; + private final Predicate predicate; + + private Check( + String name, + String sql, + Predicate predicate, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); + this.sql = sql; + this.predicate = predicate; + } + + /** + * Returns the SQL representation of the search condition (Spark SQL dialect). + */ + public String sql() { + return sql; + } + + /** + * Returns the search condition. + */ + public Predicate predicate() { + return predicate; + } + + @Override + protected String definition() { + return String.format("CHECK %s", sql != null ? sql : predicate); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + Check that = (Check) other; + return Objects.equals(name(), that.name()) && + Objects.equals(sql, that.sql) && + Objects.equals(predicate, that.predicate) && + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); + } + + @Override + public int hashCode() { + return Objects.hash(name(), sql, predicate, enforced(), validationStatus(), rely()); + } + + public static class Builder extends BaseConstraint.Builder { + + private String sql; + private Predicate predicate; + + Builder(String name) { + super(name); + } + + @Override + protected Builder self() { + return this; + } + + public Builder sql(String sql) { + this.sql = sql; + return this; + } + + public Builder predicate(Predicate predicate) { + this.predicate = predicate; + return this; + } + + public Check build() { + if (sql == null && predicate == null) { + throw new SparkIllegalArgumentException( + "INTERNAL_ERROR", + Map.of("message", "SQL and predicate in CHECK can't be both null")); + } + return new Check(name(), sql, predicate, enforced(), validationStatus(), rely()); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java new file mode 100644 index 0000000000000..f8381326a2054 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A constraint that restricts states of data in a table. + * + * @since 4.1.0 + */ +@Evolving +public interface Constraint { + /** + * Returns the name of this constraint. + */ + String name(); + + /** + * Indicates whether this constraint is actively enforced. If enforced, data modifications + * that violate the constraint fail with a constraint violation error. + */ + boolean enforced(); + + /** + * Indicates whether the existing data in the table satisfies this constraint. The constraint + * can be valid (the data is guaranteed to satisfy the constraint), invalid (some records violate + * the constraint), or unvalidated (the validity is unknown). + */ + ValidationStatus validationStatus(); + + /** + * Indicates whether this constraint is assumed to hold true if the validity is unknown. + */ + boolean rely(); + + /** + * Returns the definition of this constraint in the DDL format. + */ + String toDDL(); + + /** + * Instantiates a builder for a CHECK constraint. + * + * @param name the constraint name + * @return a CHECK constraint builder + */ + static Check.Builder check(String name) { + return new Check.Builder(name); + } + + /** + * Instantiates a builder for a UNIQUE constraint. + * + * @param name the constraint name + * @param columns columns that comprise the unique key + * @return a UNIQUE constraint builder + */ + static Unique.Builder unique(String name, NamedReference[] columns) { + return new Unique.Builder(name, columns); + } + + /** + * Instantiates a builder for a PRIMARY KEY constraint. + * + * @param name the constraint name + * @param columns columns that comprise the primary key + * @return a PRIMARY KEY constraint builder + */ + static PrimaryKey.Builder primaryKey(String name, NamedReference[] columns) { + return new PrimaryKey.Builder(name, columns); + } + + /** + * Instantiates a builder for a FOREIGN KEY constraint. + * + * @param name the constraint name + * @param columns the referencing columns + * @param refTable the referenced table identifier + * @param refColumns the referenced columns in the referenced table + * @return a FOREIGN KEY constraint builder + */ + static ForeignKey.Builder foreignKey( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns) { + return new ForeignKey.Builder(name, columns, refTable, refColumns); + } + + /** + * An indicator of the validity of the constraint. + *
+ * A constraint may be validated independently of enforcement, meaning it can be validated + * without being actively enforced, or vice versa. A constraint can be valid (the data is + * guaranteed to satisfy the constraint), invalid (some records violate the constraint), + * or unvalidated (the validity is unknown). + */ + enum ValidationStatus { + VALID, INVALID, UNVALIDATED + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java new file mode 100644 index 0000000000000..4763f95ba98b5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A FOREIGN KEY constraint. + *
+ * A FOREIGN KEY constraint specifies one or more columns (referencing columns) in a table that + * refer to corresponding columns (referenced columns) in another table. The referenced columns + * must form a UNIQUE or PRIMARY KEY constraint in the referenced table. For this constraint to be + * satisfied, each row in the table must contain values in the referencing columns that exactly + * match values of a row in the referenced table. + * + * @since 4.1.0 + */ +public class ForeignKey extends BaseConstraint { + + private final NamedReference[] columns; + private final Identifier refTable; + private final NamedReference[] refColumns; + + ForeignKey( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); + this.columns = columns; + this.refTable = refTable; + this.refColumns = refColumns; + } + + /** + * Returns the referencing columns. + */ + public NamedReference[] columns() { + return columns; + } + + /** + * Returns the referenced table. + */ + public Identifier referencedTable() { + return refTable; + } + + /** + * Returns the referenced columns in the referenced table. + */ + public NamedReference[] referencedColumns() { + return refColumns; + } + + @Override + protected String definition() { + return String.format( + "FOREIGN KEY (%s) REFERENCES %s (%s)", + toSQL(columns), + refTable, + toSQL(refColumns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + ForeignKey that = (ForeignKey) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns) && + Objects.equals(refTable, that.refTable) && + Arrays.equals(refColumns, that.refColumns) && + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), refTable, enforced(), validationStatus(), rely()); + result = 31 * result + Arrays.hashCode(columns); + result = 31 * result + Arrays.hashCode(refColumns); + return result; + } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + private final Identifier refTable; + private final NamedReference[] refColumns; + + public Builder( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns) { + super(name); + this.columns = columns; + this.refTable = refTable; + this.refColumns = refColumns; + } + + @Override + protected Builder self() { + return this; + } + + @Override + public ForeignKey build() { + return new ForeignKey( + name(), + columns, + refTable, + refColumns, + enforced(), + validationStatus(), + rely()); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java new file mode 100644 index 0000000000000..caaf29c10538b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A PRIMARY KEY constraint. + *
+ * A PRIMARY KEY constraint specifies ore or more columns as a primary key. Such constraint is + * satisfied if and only if no two rows in a table have the same non-null values in the primary + * key columns and none of the values in the specified column or columns are {@code NULL}. + * A table can have at most one primary key. + * + * @since 4.1.0 + */ +public class PrimaryKey extends BaseConstraint { + + private final NamedReference[] columns; + + PrimaryKey( + String name, + NamedReference[] columns, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); + this.columns = columns; + } + + /** + * Returns the columns that comprise the primary key. + */ + public NamedReference[] columns() { + return columns; + } + + @Override + protected String definition() { + return String.format("PRIMARY KEY (%s)", toSQL(columns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + PrimaryKey that = (PrimaryKey) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns()) && + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), enforced(), validationStatus(), rely()); + result = 31 * result + Arrays.hashCode(columns); + return result; + } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + + Builder(String name, NamedReference[] columns) { + super(name); + this.columns = columns; + } + + @Override + protected Builder self() { + return this; + } + + @Override + public PrimaryKey build() { + return new PrimaryKey(name(), columns, enforced(), validationStatus(), rely()); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java new file mode 100644 index 0000000000000..394ad6b814e61 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A UNIQUE constraint. + *
+ * A UNIQUE constraint specifies one or more columns as unique columns. Such constraint is satisfied + * if and only if no two rows in a table have the same non-null values in the unique columns. + * + * @since 4.1.0 + */ +public class Unique extends BaseConstraint { + + private final NamedReference[] columns; + + private Unique( + String name, + NamedReference[] columns, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); + this.columns = columns; + } + + /** + * Returns the columns that comprise the unique key. + */ + public NamedReference[] columns() { + return columns; + } + + @Override + protected String definition() { + return String.format("UNIQUE (%s)", toSQL(columns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + Unique that = (Unique) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns()) && + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), enforced(), validationStatus(), rely()); + result = 31 * result + Arrays.hashCode(columns); + return result; + } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + + Builder(String name, NamedReference[] columns) { + super(name); + this.columns = columns; + } + + @Override + protected Builder self() { + return this; + } + + public Unique build() { + return new Unique(name(), columns, enforced(), validationStatus(), rely()); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1b45fcde91266..fb0721a0db6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1190,6 +1190,21 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString } case _ => } + + case AddConstraint(table: ResolvedTable, check: CheckConstraint) => + if (!check.resolved) { + check.child.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", + messageParameters = Map.empty + ) + } + + if (!check.deterministic) { + check.child.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + messageParameters = Map.empty + ) + } case _ => } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 05158fbee3de6..65debcb102285 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkThrowable -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import 
org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.DefaultColumnAnalyzer import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, MapType, StructType} @@ -46,20 +47,55 @@ object ResolveTableSpec extends Rule[LogicalPlan] { preparedPlan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) { case t: CreateTable => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, + fakeRelationFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: CreateTableAsSelect => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) case t: ReplaceTable => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, + fakeRelationFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: ReplaceTableAsSelect => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) } } + private def fakeRelationFromColumns(columns: Seq[ColumnDefinition]): Option[LogicalPlan] = { + val attributeList = columns.map { col => + AttributeReference(col.name, col.dataType)() + } + Some(LocalRelation(attributeList)) + } + + private def analyzeConstraints( + constraints: Seq[TableConstraint], + fakeRelation: LogicalPlan): Seq[TableConstraint] = { + val analyzedExpressions = constraints.map { + case c: CheckConstraint => + val alias = Alias(c.child, c.name)() + val project = Project(Seq(alias), fakeRelation) + val analyzed = DefaultColumnAnalyzer.execute(project) + DefaultColumnAnalyzer.checkAnalysis0(analyzed) + + val analyzedExpression = (analyzed collectFirst { + case Project(Seq(Alias(e: Expression, _)), _) => e + }).get + if (!analyzedExpression.deterministic) { + analyzedExpression.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + messageParameters = Map.empty + ) + } + c.withNewChildren(Seq(analyzedExpression)).asInstanceOf[CheckConstraint] + case other => other + } + analyzedExpressions + } + /** Helper method to resolve the table specification within a logical plan. 
*/ private def resolveTableSpec( input: LogicalPlan, tableSpec: TableSpecBase, + fakeRelation: Option[LogicalPlan], withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match { case u: UnresolvedTableSpec if u.optionExpression.resolved => val newOptions: Seq[(String, String)] = u.optionExpression.options.map { @@ -86,6 +122,12 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } (key, newValue) } + val newConstraints = if (fakeRelation.isDefined) { + analyzeConstraints(u.constraints, fakeRelation.get) + } else { + u.constraints + } + // assert(newConstraints.childrenResolved) val newTableSpec = TableSpec( properties = u.properties, provider = u.provider, @@ -94,7 +136,8 @@ object ResolveTableSpec extends Rule[LogicalPlan] { comment = u.comment, collation = u.collation, serde = u.serde, - external = u.external) + external = u.external, + constraints = newConstraints.map(_.asConstraint(isCreateTable = true))) withNewSpec(newTableSpec) case _ => input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala new file mode 100644 index 0000000000000..f4c171868ad3a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions + +import org.antlr.v4.runtime.ParserRuleContext + +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder +import org.apache.spark.sql.connector.catalog.constraints.Constraint +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.types.{DataType, StringType} + +trait TableConstraint { + // Convert to a data source v2 constraint + def asConstraint(isCreateTable: Boolean): Constraint + + def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint + + def name: String + + def characteristic: ConstraintCharacteristic + + def generateConstraintNameIfNeeded(tableName: String): TableConstraint = { + if (name == null || name.isEmpty) { + this.withNameAndCharacteristic( + name = generateConstraintName(tableName), + c = characteristic, + ctx = null) + } else { + this + } + } + + // Generate a constraint name based on the table name if the name is not specified + protected def generateConstraintName(tableName: String): String + + protected def defaultConstraintCharacteristic: ConstraintCharacteristic + + protected def getCharacteristicValues: (Boolean, Boolean) = { + val rely = characteristic.rely.getOrElse(defaultConstraintCharacteristic.rely.get) + val enforced = characteristic.enforced.getOrElse(defaultConstraintCharacteristic.enforced.get) + (rely, enforced) + } +} + +case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean]) + +object ConstraintCharacteristic { + val empty: ConstraintCharacteristic = ConstraintCharacteristic(None, None) +} + +case class CheckConstraint( + child: Expression, + condition: String, + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends UnaryExpression + with Unevaluable + with TableConstraint { + + def asConstraint(isCreateTable: Boolean): Constraint = { + val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull + val (rely, enforced) = getCharacteristicValues + // The validation status is set to UNVALIDATED for create table and + // VALID for alter table. 
+ val validateStatus = if (isCreateTable) { + Constraint.ValidationStatus.UNVALIDATED + } else { + Constraint.ValidationStatus.VALID + } + Constraint + .check(name) + .sql(condition) + .predicate(predicate) + .rely(rely) + .enforced(enforced) + .validationStatus(validateStatus) + .build() + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + copy(name = name, characteristic = c) + } + + override protected def generateConstraintName(tableName: String): String = { + val base = condition.filter(_.isLetterOrDigit).take(20) + val rand = scala.util.Random.alphanumeric.take(6).mkString + s"${tableName}_chk_${base}_$rand" + } + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(true), rely = Some(false)) + + override def sql: String = s"CONSTRAINT $name CHECK ($condition)" + + override def dataType: DataType = StringType +} + +case class PrimaryKeyConstraint( + columns: Seq[String], + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + + override def asConstraint(isCreateTable: Boolean): Constraint = { + val (rely, enforced) = getCharacteristicValues + Constraint + .primaryKey(name, columns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "PRIMARY KEY"), + ctx = ctx + ) + } + copy(name = name, characteristic = c) + } + + override protected def generateConstraintName(tableName: String): String = s"${tableName}_pk" + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) +} + +case class UniqueConstraint( + columns: Seq[String], + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + + override def asConstraint(isCreateTable: Boolean): Constraint = { + val (rely, enforced) = getCharacteristicValues + Constraint + .unique(name, columns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "UNIQUE"), + ctx = ctx + ) + } + copy(name = name, characteristic = c) + } + + override protected def generateConstraintName(tableName: String): String = { + val base = columns.map(_.filter(_.isLetterOrDigit)).sorted.mkString("_").take(20) + val rand = scala.util.Random.alphanumeric.take(6).mkString + s"${tableName}_uk_${base}_$rand" + } + + override def defaultConstraintCharacteristic: ConstraintCharacteristic 
= + ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) +} + +case class ForeignKeyConstraint( + override val name: String = null, + childColumns: Seq[String] = Seq.empty, + parentTableId: Seq[String] = Seq.empty, + parentColumns: Seq[String] = Seq.empty, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def asConstraint(isCreateTable: Boolean): Constraint = { + val (rely, enforced) = getCharacteristicValues + Constraint + .foreignKey(name, + childColumns.map(FieldReference.column).toArray, + parentTableId.asIdentifier, + parentColumns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "FOREIGN KEY"), + ctx = ctx + ) + } + copy(name = name, characteristic = c) + } + + override protected def generateConstraintName(tableName: String): String = + s"${tableName}_${parentTableId.last}_fk" + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd7af021d8ffb..6bdc5170b68c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -117,6 +117,23 @@ class AstBuilder extends DataTypeAstBuilder } } + /** + * Retrieves the original input text for a given parser context, preserving all whitespace and + * formatting. + * + * ANTLR's default getText method removes whitespace because lexer rules typically skip it. + * This utility method extracts the exact text from the original input stream, using token + * indices. + * + * @param ctx The parser context to retrieve original text from. + * @return The original input text, including all whitespaces and formatting. + */ + private def getOriginalText(ctx: ParserRuleContext): String = { + ctx.getStart.getInputStream.getText( + new Interval(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex) + ) + } + /** * Override the default behavior for all visit methods. This will only return a non-null result * when the context has only one child. This is done because there is no generic method to @@ -3843,24 +3860,25 @@ class AstBuilder extends DataTypeAstBuilder /** * Create top level table schema. */ - protected def createSchema(ctx: ColDefinitionListContext): StructType = { - val columns = Option(ctx).toArray.flatMap(visitColDefinitionList) - StructType(columns.map(_.toV1Column)) + protected def createSchema(ctx: TableElementListContext): StructType = { + val (cols, _) = visitTableElementList(ctx) + StructType(cols.map(_.toV1Column)) } /** * Get CREATE TABLE column definitions. 
*/ override def visitColDefinitionList( - ctx: ColDefinitionListContext): Seq[ColumnDefinition] = withOrigin(ctx) { - ctx.colDefinition().asScala.map(visitColDefinition).toSeq + ctx: ColDefinitionListContext): TableElementList = withOrigin(ctx) { + val (colDefs, constraints) = ctx.colDefinition().asScala.map(visitColDefinition).toSeq.unzip + (colDefs, constraints.flatten) } /** * Get a CREATE TABLE column definition. */ override def visitColDefinition( - ctx: ColDefinitionContext): ColumnDefinition = withOrigin(ctx) { + ctx: ColDefinitionContext): ColumnAndConstraint = withOrigin(ctx) { import ctx._ val name: String = colName.getText @@ -3869,6 +3887,7 @@ class AstBuilder extends DataTypeAstBuilder var defaultExpression: Option[DefaultExpressionContext] = None var generationExpression: Option[GenerationExpressionContext] = None var commentSpec: Option[CommentSpecContext] = None + var columnConstraint: Option[ColumnConstraintDefinitionContext] = None ctx.colDefinitionOption().asScala.foreach { option => if (option.NULL != null) { blockBang(option.errorCapturingNot) @@ -3902,10 +3921,17 @@ class AstBuilder extends DataTypeAstBuilder } commentSpec = Some(spec) } + Option(option.columnConstraintDefinition()).foreach { definition => + if (columnConstraint.isDefined) { + throw QueryParsingErrors.duplicateTableColumnDescriptor( + option, name, "CONSTRAINT") + } + columnConstraint = Some(definition) + } } val dataType = typedVisit[DataType](ctx.dataType) - ColumnDefinition( + val columnDef = ColumnDefinition( name = name, dataType = dataType, nullable = nullable, @@ -3918,8 +3944,61 @@ class AstBuilder extends DataTypeAstBuilder case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType) } ) + val constraint = columnConstraint.map(c => visitColumnConstraintDefinition(name, c)) + (columnDef, constraint) + } + + private def visitColumnConstraintDefinition( + columnName: String, + ctx: ColumnConstraintDefinitionContext): TableConstraint = { + withOrigin(ctx) { + val name = if (ctx.name != null) { + ctx.name.getText + } else { + null + } + val constraintCharacteristic = + visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq) + val expr = visitColumnConstraint(columnName, ctx.columnConstraint()) + + expr.withNameAndCharacteristic(name, constraintCharacteristic, ctx) + } + } + + private def visitColumnConstraint( + columnName: String, + ctx: ColumnConstraintContext): TableConstraint = withOrigin(ctx) { + val columns = Seq(columnName) + if (ctx.checkConstraint() != null) { + visitCheckConstraint(ctx.checkConstraint()) + } else if (ctx.uniqueSpec() != null) { + visitUniqueSpec(ctx.uniqueSpec(), columns) + } else { + assert(ctx.referenceSpec() != null) + val (tableId, refColumns) = visitReferenceSpec(ctx.referenceSpec()) + ForeignKeyConstraint( + childColumns = columns, + parentTableId = tableId, + parentColumns = refColumns) + } } + private def visitUniqueSpec(ctx: UniqueSpecContext, columns: Seq[String]): TableConstraint = + withOrigin(ctx) { + if (ctx.UNIQUE() != null) { + UniqueConstraint(columns) + } else { + PrimaryKeyConstraint(columns) + } + } + + override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) = + withOrigin(ctx) { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val refColumns = visitIdentifierList(ctx.parentColumns) + (tableId, refColumns) + } + /** * Create a location string. 
*/ @@ -3946,9 +4025,7 @@ class AstBuilder extends DataTypeAstBuilder // use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's // `sql` method. Note: `exprCtx.getText` returns a string without spaces, so we need to // get the text from the underlying char stream instead. - val start = exprCtx.getStart.getStartIndex - val end = exprCtx.getStop.getStopIndex - val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end)) + val originalSQL = getOriginalText(exprCtx) DefaultValueExpression(expr, originalSQL) } @@ -4140,6 +4217,10 @@ class AstBuilder extends DataTypeAstBuilder Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec]) + type ColumnAndConstraint = (ColumnDefinition, Option[TableConstraint]) + + type TableElementList = (Seq[ColumnDefinition], Seq[TableConstraint]) + /** * Validate a create table statement and return the [[TableIdentifier]]. */ @@ -4680,6 +4761,33 @@ class AstBuilder extends DataTypeAstBuilder } } + override def visitTableElementList(ctx: TableElementListContext): TableElementList = + withOrigin(ctx) { + if (ctx == null) { + return (Nil, Nil) + } + val columnDefs = new ArrayBuffer[ColumnDefinition]() + val constraints = new ArrayBuffer[TableConstraint]() + + ctx.tableElement().asScala.foreach { element => + if (element.tableConstraintDefinition() != null) { + constraints += visitTableConstraintDefinition(element.tableConstraintDefinition()) + } else { + val (colDef, constraintOpt) = visitColDefinition(element.colDefinition()) + columnDefs += colDef + constraintOpt.foreach(constraints += _) + } + } + + // check if there are multiple primary keys + val primaryKeys = constraints.filter(_.isInstanceOf[PrimaryKeyConstraint]) + if (primaryKeys.size > 1) { + throw QueryParsingErrors.multiplePrimaryKeysError(ctx) + } + + (columnDefs.toSeq, constraints.toSeq) + } + /** * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan. * @@ -4714,10 +4822,12 @@ class AstBuilder extends DataTypeAstBuilder val (identifierContext, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) + val (columns, constraints) = visitTableElementList(ctx.tableElementList()) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, - collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) + collation, serdeInfo, clusterBySpec) = + visitCreateTableClauses(ctx.createTableClauses()) if (provider.isDefined && serdeInfo.isDefined) { invalidStatement(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) @@ -4734,33 +4844,36 @@ class AstBuilder extends DataTypeAstBuilder bucketSpec.map(_.asTransform) ++ clusterBySpec.map(_.asTransform) - val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external) - - Option(ctx.query).map(plan) match { - case Some(_) if columns.nonEmpty => - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) + withIdentClause(identifierContext, identifiers => { + val namedConstraints = + constraints.map(c => c.generateConstraintNameIfNeeded(identifiers.last)) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, + collation, serdeInfo, external, namedConstraints) + val identifier = withOrigin(identifierContext) { + UnresolvedIdentifier(identifiers) + } + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) - case Some(_) if partCols.nonEmpty => - // non-reference partition columns are not allowed because schema can't be specified - operationNotAllowed( - "Partition column types may not be specified in Create Table As Select (CTAS)", - ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) - case Some(query) => - CreateTableAsSelect(withIdentClause(identifierContext, UnresolvedIdentifier(_)), - partitioning, query, tableSpec, Map.empty, ifNotExists) + case Some(query) => + CreateTableAsSelect(identifier, partitioning, query, tableSpec, Map.empty, ifNotExists) - case _ => - // Note: table schema includes both the table columns list and the partition columns - // with data type. - val allColumns = columns ++ partCols - CreateTable( - withIdentClause(identifierContext, UnresolvedIdentifier(_)), - allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists) - } + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. 
+ val allColumns = columns ++ partCols + CreateTable(identifier, allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists) + } + }) } /** @@ -4796,7 +4909,7 @@ class AstBuilder extends DataTypeAstBuilder val orCreate = ctx.replaceTableHeader().CREATE() != null val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) - val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) + val (columns, constraints) = visitTableElementList(ctx.tableElementList()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) if (provider.isDefined && serdeInfo.isDefined) { @@ -4808,34 +4921,38 @@ class AstBuilder extends DataTypeAstBuilder bucketSpec.map(_.asTransform) ++ clusterBySpec.map(_.asTransform) - val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false) - - Option(ctx.query).map(plan) match { - case Some(_) if columns.nonEmpty => - operationNotAllowed( - "Schema may not be specified in a Replace Table As Select (RTAS) statement", - ctx) + val identifierContext = ctx.replaceTableHeader().identifierReference() + withIdentClause(identifierContext, identifiers => { + val namedConstraints = + constraints.map(c => c.generateConstraintNameIfNeeded(identifiers.last)) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, + collation, serdeInfo, external = false, namedConstraints) + val identifier = withOrigin(identifierContext) { + UnresolvedIdentifier(identifiers) + } + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Replace Table As Select (RTAS) statement", + ctx) - case Some(_) if partCols.nonEmpty => - // non-reference partition columns are not allowed because schema can't be specified - operationNotAllowed( - "Partition column types may not be specified in Replace Table As Select (RTAS)", - ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Replace Table As Select (RTAS)", + ctx) - case Some(query) => - ReplaceTableAsSelect( - withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)), - partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate) + case Some(query) => + ReplaceTableAsSelect(identifier, partitioning, query, tableSpec, + writeOptions = Map.empty, orCreate = orCreate) - case _ => - // Note: table schema includes both the table columns list and the partition columns - // with data type. - val allColumns = columns ++ partCols - ReplaceTable( - withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)), - allColumns, partitioning, tableSpec, orCreate = orCreate) - } + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. 
+ val allColumns = columns ++ partCols + ReplaceTable(identifier, allColumns, partitioning, tableSpec, orCreate = orCreate) + } + }) } /** @@ -5237,6 +5354,112 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } + override def visitTableConstraintDefinition( + ctx: TableConstraintDefinitionContext): TableConstraint = + withOrigin(ctx) { + val name = if (ctx.name != null) { + ctx.name.getText + } else { + null + } + val constraintCharacteristic = + visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq) + val expr = + visitTableConstraint(ctx.tableConstraint()).asInstanceOf[TableConstraint] + + expr.withNameAndCharacteristic(name, constraintCharacteristic, ctx) + } + + override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint = + withOrigin(ctx) { + val condition = getOriginalText(ctx.expr) + CheckConstraint( + child = expression(ctx.booleanExpression()), + condition = condition) + } + + + override def visitUniqueConstraint(ctx: UniqueConstraintContext): TableConstraint = + withOrigin(ctx) { + val columns = visitIdentifierList(ctx.identifierList()) + visitUniqueSpec(ctx.uniqueSpec(), columns) + } + + override def visitForeignKeyConstraint(ctx: ForeignKeyConstraintContext): TableConstraint = + withOrigin(ctx) { + val columns = visitIdentifierList(ctx.identifierList()) + val (parentTableId, parentColumns) = visitReferenceSpec(ctx.referenceSpec()) + ForeignKeyConstraint( + childColumns = columns, + parentTableId = parentTableId, + parentColumns = parentColumns) + } + + private def visitConstraintCharacteristics( + constraintCharacteristics: Seq[ConstraintCharacteristicContext]): ConstraintCharacteristic = { + var enforcement: Option[String] = None + var rely: Option[String] = None + constraintCharacteristics.foreach { + case e if e.enforcedCharacteristic() != null => + val text = getOriginalText(e.enforcedCharacteristic()).toUpperCase(Locale.ROOT) + if (enforcement.isDefined) { + val invalidCharacteristics = s"${enforcement.get}, $text" + throw QueryParsingErrors.invalidConstraintCharacteristics( + e.enforcedCharacteristic(), invalidCharacteristics) + } else { + enforcement = Some(text) + } + + case r if r.relyCharacteristic() != null => + val text = r.relyCharacteristic().getText.toUpperCase(Locale.ROOT) + if (rely.isDefined) { + val invalidCharacteristics = s"${rely.get}, $text" + throw QueryParsingErrors.invalidConstraintCharacteristics( + r.relyCharacteristic(), invalidCharacteristics) + } else { + rely = Some(text) + } + } + ConstraintCharacteristic(enforcement.map(_ == "ENFORCED"), rely.map(_ == "RELY")) + } + + /** + * Parse an [[AlterTableCommand]] with table constraint. + * + * For example: + * {{{ + * ALTER TABLE table1 CONSTRAINT constraint_name CHECK (a > 0) + * }}} + */ + override def visitAddTableConstraint(ctx: AddTableConstraintContext): LogicalPlan = + withOrigin(ctx) { + val tableConstraint = visitTableConstraintDefinition(ctx.tableConstraintDefinition()) + withIdentClause(ctx.identifierReference, identifiers => { + val table = UnresolvedTable(identifiers, "ALTER TABLE ... ADD CONSTRAINT") + val namedConstraint = tableConstraint.generateConstraintNameIfNeeded(identifiers.last) + AddConstraint(table, namedConstraint) + }) + } + + /** + * Parse a [[DropConstraint]] command. 
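+   * The statement accepts an optional IF EXISTS clause and an optional CASCADE or RESTRICT
+   * drop behavior; RESTRICT is the default and, like omitting the clause, maps to
+   * `cascade = false`.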
+ * + * For example: + * {{{ + * ALTER TABLE table1 DROP CONSTRAINT constraint_name + * }}} + */ + override def visitDropTableConstraint(ctx: DropTableConstraintContext): LogicalPlan = + withOrigin(ctx) { + val table = createUnresolvedTable( + ctx.identifierReference, "ALTER TABLE ... DROP CONSTRAINT") + DropConstraint( + table, + ctx.name.getText, + ifExists = ctx.EXISTS() != null, + cascade = ctx.CASCADE() != null) + } + /** * Parse [[SetViewProperties]] or [[SetTableProperties]] commands. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index a0def801ee6f7..845c6a64a5694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ClusterBySpec -import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable} import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils} import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -288,3 +288,31 @@ case class AlterTableCollation( protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) } + +/** + * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command. + */ +case class AddConstraint( + table: LogicalPlan, + tableConstraint: TableConstraint) extends AlterTableCommand { + override def changes: Seq[TableChange] = { + val constraint = tableConstraint.asConstraint(isCreateTable = false) + Seq(TableChange.addConstraint(constraint, constraint.enforced())) + } + + protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) +} + +/** + * The logical plan of the ALTER TABLE ... DROP CONSTRAINT command. 
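+ *
+ * It resolves into a single `TableChange.dropConstraint(name, ifExists, cascade)` change on the
+ * target table; e.g. (illustrative) `ALTER TABLE t DROP CONSTRAINT IF EXISTS c1 CASCADE`
+ * becomes `DropConstraint(table, "c1", ifExists = true, cascade = true)`.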
+ */ +case class DropConstraint( + table: LogicalPlan, + name: String, + ifExists: Boolean, + cascade: Boolean) extends AlterTableCommand { + override def changes: Seq[TableChange] = + Seq(TableChange.dropConstraint(name, ifExists, cascade)) + + protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 1056a30c5f758..07b6e912b584d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -23,14 +23,15 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, UnaryExpression, Unevaluable, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, ReplaceDataProjections, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -1520,7 +1521,9 @@ case class UnresolvedTableSpec( comment: Option[String], collation: Option[String], serde: Option[SerdeInfo], - external: Boolean) extends UnaryExpression with Unevaluable with TableSpecBase { + external: Boolean, + constraints: Seq[TableConstraint]) + extends UnaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") @@ -1566,9 +1569,11 @@ case class TableSpec( comment: Option[String], collation: Option[String], serde: Option[SerdeInfo], - external: Boolean) extends TableSpecBase { + external: Boolean, + constraints: Seq[Constraint] = Seq.empty) extends TableSpecBase { def withNewLocation(newLocation: Option[String]): TableSpec = { - TableSpec(properties, provider, options, newLocation, comment, collation, serde, external) + TableSpec(properties, provider, options, newLocation, + comment, collation, serde, external, constraints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 97cc263c56c5f..3dc8d16d3eef0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.catalog import java.util -import java.util.Collections +import java.util.{Collections, Locale} import scala.jdk.CollectionConverters._ @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec} import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -194,6 +195,49 @@ private[sql] object CatalogV2Util { newPartitioning } + /** + * Apply Table Constraints changes to an existing set of constraints and return the result. + */ + def applyConstraintChanges( + table: Table, + changes: Seq[TableChange]): Array[Constraint] = { + val constraints = table.constraints() + + def findExistingConstraint(name: String): Option[Constraint] = { + constraints.find(_.name.toLowerCase(Locale.ROOT) == name.toLowerCase(Locale.ROOT)) + } + + changes.foldLeft(constraints) { (constraints, change) => + change match { + case add: AddConstraint => + val newConstraint = add.getConstraint + val existingConstraint = findExistingConstraint(newConstraint.name) + if (existingConstraint.isDefined) { + throw new AnalysisException( + errorClass = "CONSTRAINT_ALREADY_EXISTS", + messageParameters = + Map("constraintName" -> existingConstraint.get.name, + "oldConstraint" -> existingConstraint.get.toDDL)) + } + constraints :+ newConstraint + + case drop: DropConstraint => + val existingConstraint = findExistingConstraint(drop.getName) + if (existingConstraint.isEmpty && !drop.isIfExists) { + throw new AnalysisException( + errorClass = "CONSTRAINT_DOES_NOT_EXIST", + messageParameters = + Map("constraintName" -> drop.getName, "tableName" -> table.name())) + } + constraints.filterNot(_.name == drop.getName) + + case _ => + // ignore non-constraint changes + constraints + } + }.toArray + } + /** * Apply schema changes to a schema and return the result. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 133670d5fcced..0afdffb8b5e7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.util.ArrayImplicits._ class CreateTablePartitioningValidationSuite extends AnalysisTest { val tableSpec = - UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false) + UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false, + Seq.empty) test("CreateTableAsSelect: fail missing top-level column") { val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index c8d2de9c6b8de..1589bcb8a3d7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2705,7 +2705,7 @@ class DDLParserSuite extends AnalysisTest { val createTableResult = CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false) // Parse the CREATE TABLE statement twice, swapping the order of the NOT NULL and DEFAULT // options, to make sure that the parser accepts any ordering of these options. comparePlans(parsePlan( @@ -2718,7 +2718,7 @@ class DDLParserSuite extends AnalysisTest { "b STRING NOT NULL DEFAULT 'abc') USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) // These ALTER TABLE statements should parse successfully. 
comparePlans( parsePlan("ALTER TABLE t1 ADD COLUMN x int NOT NULL DEFAULT 42"), @@ -2881,12 +2881,12 @@ class DDLParserSuite extends AnalysisTest { "CREATE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) comparePlans(parsePlan( "REPLACE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) // Two generation expressions checkError( exception = parseException("CREATE TABLE my_tab(a INT, " + @@ -2957,7 +2957,8 @@ class DDLParserSuite extends AnalysisTest { None, None, None, - false + false, + Seq.empty ), false ) @@ -2980,7 +2981,8 @@ class DDLParserSuite extends AnalysisTest { None, None, None, - false + false, + Seq.empty ), false ) @@ -3273,7 +3275,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false), false)) + None, None, Some(collation), None, false, Seq.empty), false)) } } @@ -3285,7 +3287,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false), false)) + None, None, Some(collation), None, false, Seq.empty), false)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala new file mode 100644 index 0000000000000..6b4bea3b14ccd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.constraints.Constraint +import org.apache.spark.sql.connector.catalog.constraints.Constraint.ValidationStatus +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.types.IntegerType + +class ConstraintSuite extends SparkFunSuite { + + test("CHECK constraint toDDL") { + val con1 = Constraint.check("con1") + .sql("id > 10") + .enforced(true) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() + assert(con1.toDDL == "CONSTRAINT con1 CHECK id > 10 ENFORCED VALID RELY") + + val con2 = Constraint.check("con2") + .predicate( + new Predicate( + "=", + Array[Expression]( + FieldReference(Seq("a", "b.c", "d")), + LiteralValue(1, IntegerType)))) + .enforced(false) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() + assert(con2.toDDL == "CONSTRAINT con2 CHECK a.`b.c`.d = 1 NOT ENFORCED VALID RELY") + + val con3 = Constraint.check("con3") + .sql("a.b.c <=> 1") + .predicate( + new Predicate( + "<=>", + Array[Expression]( + FieldReference(Seq("a", "b", "c")), + LiteralValue(1, IntegerType)))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() + assert(con3.toDDL == "CONSTRAINT con3 CHECK a.b.c <=> 1 NOT ENFORCED INVALID NORELY") + + val con4 = Constraint.check("con4").sql("a = 1").build() + assert(con4.toDDL == "CONSTRAINT con4 CHECK a = 1 ENFORCED UNVALIDATED NORELY") + } + + test("UNIQUE constraint toDDL") { + val con1 = Constraint.unique( + "con1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d")))) + .enforced(false) + .validationStatus(ValidationStatus.UNVALIDATED) + .rely(true) + .build() + assert(con1.toDDL == "CONSTRAINT con1 UNIQUE (a.b.c, d) NOT ENFORCED UNVALIDATED RELY") + + val con2 = Constraint.unique( + "con2", + Array[NamedReference](FieldReference(Seq("a.b", "x", "y")), FieldReference(Seq("d")))) + .enforced(false) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() + assert(con2.toDDL == "CONSTRAINT con2 UNIQUE (`a.b`.x.y, d) NOT ENFORCED VALID RELY") + } + + test("PRIMARY KEY constraint toDDL") { + val pk1 = Constraint.primaryKey( + "pk1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d")))) + .enforced(true) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() + assert(pk1.toDDL == "CONSTRAINT pk1 PRIMARY KEY (a.b.c, d) ENFORCED VALID RELY") + + val pk2 = Constraint.primaryKey( + "pk2", + Array[NamedReference](FieldReference(Seq("x.y", "z")), FieldReference(Seq("id")))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() + assert(pk2.toDDL == "CONSTRAINT pk2 PRIMARY KEY (`x.y`.z, id) NOT ENFORCED INVALID NORELY") + } + + test("FOREIGN KEY constraint toDDL") { + val fk1 = Constraint.foreignKey( + "fk1", + Array[NamedReference](FieldReference(Seq("col1")), FieldReference(Seq("col2"))), + Identifier.of(Array("schema"), "table"), + Array[NamedReference](FieldReference(Seq("ref_col1")), FieldReference(Seq("ref_col2")))) + .enforced(true) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() + assert(fk1.toDDL == "CONSTRAINT fk1 FOREIGN KEY (col1, col2) " + + "REFERENCES schema.table (ref_col1, ref_col2) " + + "ENFORCED VALID RELY") + + val fk2 = Constraint.foreignKey( + 
"fk2", + Array[NamedReference](FieldReference(Seq("x.y", "z"))), + Identifier.of(Array.empty[String], "other_table"), + Array[NamedReference](FieldReference(Seq("other_id")))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() + assert(fk2.toDDL == "CONSTRAINT fk2 FOREIGN KEY (`x.y`.z) " + + "REFERENCES other_table (other_id) " + + "NOT ENFORCED INVALID NORELY") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index c27b8fea059f7..e8f2cf6979e86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog import java.util +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage} @@ -40,7 +41,8 @@ class InMemoryTable( numPartitions: Option[Int] = None, advisoryPartitionSize: Option[Long] = None, isDistributionStrictlyRequired: Boolean = true, - override val numRowsPerSplit: Int = Int.MaxValue) + override val numRowsPerSplit: Int = Int.MaxValue, + override val constraints: Array[Constraint] = Array.empty) extends InMemoryBaseTable(name, schema, partitioning, properties, distribution, ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired, numRowsPerSplit) with SupportsDelete { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 56ed3bb243e19..5397e6cfd999f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -87,11 +88,13 @@ class BasicInMemoryTableCatalog extends TableCatalog { ident: Identifier, columns: Array[Column], partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: util.Map[String, String], + constraints: Array[Constraint]): Table = { createTable(ident, columns, partitions, properties, Distributions.unspecified(), - Array.empty, None, None) + Array.empty, None, None, constraints) } + // scalastyle:off argcount def createTable( ident: Identifier, columns: Array[Column], @@ -101,8 +104,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { ordering: Array[SortOrder], requiredNumPartitions: Option[Int], advisoryPartitionSize: Option[Long], + constraints: Array[Constraint], 
distributionStrictlyRequired: Boolean = true, numRowsPerSplit: Int = Int.MaxValue): Table = { + // scalastyle:on argcount val schema = CatalogV2Util.v2ColumnsToStructType(columns) if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) @@ -113,7 +118,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { val tableName = s"$name.${ident.quoted}" val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, ordering, requiredNumPartitions, advisoryPartitionSize, distributionStrictlyRequired, - numRowsPerSplit) + numRowsPerSplit, constraints) tables.put(ident, table) namespaces.putIfAbsent(ident.namespace.toList, Map()) table @@ -124,13 +129,19 @@ class BasicInMemoryTableCatalog extends TableCatalog { val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE") val finalPartitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) + val constraints = CatalogV2Util.applyConstraintChanges(table, changes) // fail if the last column in the schema was dropped if (schema.fields.isEmpty) { throw new IllegalArgumentException(s"Cannot drop all fields") } - val newTable = new InMemoryTable(table.name, schema, finalPartitioning, properties) + val newTable = new InMemoryTable( + table.name, + schema, + finalPartitioning, + properties, + constraints = constraints) .withData(table.data) tables.put(ident, newTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index aff65496b763b..3b4f6475a6bca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -688,7 +688,8 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { comment = { if (description.isEmpty) None else Some(description) }, collation = None, serde = None, - external = tableType == CatalogTableType.EXTERNAL) + external = tableType == CatalogTableType.EXTERNAL, + constraints = Seq.empty) val plan = CreateTable( name = UnresolvedIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index b423c89fff3db..501b4985128dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -213,7 +213,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand(df.sparkSession) { CreateTableAsSelect( UnresolvedIdentifier( @@ -478,7 +479,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) ReplaceTableAsSelect( UnresolvedIdentifier(nameParts), partitioningAsV2, @@ -499,7 +501,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = 
extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) CreateTableAsSelect( UnresolvedIdentifier(nameParts), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala index e4efee93d2a08..01b3619f12363 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala @@ -154,7 +154,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) comment = None, collation = None, serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand( CreateTableAsSelect( UnresolvedIdentifier(tableName), @@ -220,7 +221,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) comment = None, collation = None, serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand(ReplaceTableAsSelect( UnresolvedIdentifier(tableName), partitioning.getOrElse(Seq.empty) ++ clustering, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index 96e8755577542..471c5feadaabc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -175,7 +175,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D None, None, None, - external = false) + external = false, + constraints = Seq.empty) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8859b7b421b3c..b40cb82e9cc08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -365,7 +365,7 @@ class SparkSqlAstBuilder extends AstBuilder { visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) - val schema = Option(ctx.colDefinitionList()).map(createSchema) + val schema = Option(ctx.tableElementList()).map(createSchema) logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... 
instead") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala index f55fbafe11ddb..25e5292f36723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -43,7 +43,8 @@ case class CreateTableExec( override protected def run(): Seq[InternalRow] = { if (!catalog.tableExists(identifier)) { try { - catalog.createTable(identifier, columns, partitioning.toArray, tableProperties.asJava) + catalog.createTable(identifier, columns, partitioning.toArray, tableProperties.asJava, + tableSpec.constraints.toArray) } catch { case _: TableAlreadyExistsException if ignoreIfExists => logWarning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 894a3a10d4193..9c0122c4cd318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -48,7 +48,8 @@ case class ReplaceTableExec( } else if (!orCreate) { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } - catalog.createTable(ident, columns, partitioning.toArray, tableProperties.asJava) + catalog.createTable(ident, columns, partitioning.toArray, tableProperties.asJava, + tableSpec.constraints.toArray) Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index b48ff7121c767..cf5212351b529 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -21,6 +21,7 @@ import java.util.Collections import org.apache.spark.{SparkConf, SparkRuntimeException} import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} @@ -214,7 +215,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("arr", ArrayType(structType, containsNull = false))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, null AS arr") @@ -261,7 +263,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("arr", ArrayType(structType, containsNull = true))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql( @@ -315,7 +318,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("m", MapType(IntegerType, 
IntegerType, valueContainsNull = false))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, null AS m") @@ -354,7 +358,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("m", MapType(structType, structType, valueContainsNull = true))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, map(named_struct('x', 1, 'y', 1), null) AS m") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index c24f52bd93070..412066a1a41ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -238,7 +238,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { catalog: InMemoryTableCatalog = catalog): Unit = { catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps, Distributions.unspecified(), Array.empty, None, None, - numRowsPerSplit = 1) + Array.empty, numRowsPerSplit = 1) } private val customers: String = "customers" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index a1089b4291e90..300492577b1fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -54,7 +54,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Seq.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, @@ -79,7 +79,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Seq.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, @@ -105,7 +105,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Seq.empty) val plan = ReplaceTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, @@ -130,7 +130,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Seq.empty) val plan = ReplaceTableAsSelect( 
UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index f62e092138a98..01a4a52189f58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -977,7 +977,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) val distribution = Distributions.ordered(ordering) - catalog.createTable(ident, columns, Array.empty, emptyProps, distribution, ordering, None, None) + catalog.createTable(ident, columns, Array.empty, emptyProps, + distribution, ordering, None, None, Array.empty) withTempDir { checkpointDir => val inputData = ContinuousMemoryStream[(Long, String, Date)] @@ -1218,7 +1219,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase // scalastyle:on argcount catalog.createTable(ident, columns, Array.empty, emptyProps, tableDistribution, - tableOrdering, tableNumPartitions, tablePartitionSize, distributionStrictlyRequired) + tableOrdering, tableNumPartitions, tablePartitionSize, Array.empty, + distributionStrictlyRequired) val df = if (!dataSkewed) { spark.createDataFrame(Seq( @@ -1320,7 +1322,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase expectAnalysisException: Boolean = false): Unit = { catalog.createTable(ident, columns, Array.empty, emptyProps, tableDistribution, - tableOrdering, tableNumPartitions, tablePartitionSize) + tableOrdering, tableNumPartitions, tablePartitionSize, Array.empty) withTempDir { checkpointDir => val inputData = MemoryStream[(Long, String, Date)] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala new file mode 100644 index 0000000000000..eaf51a980324a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.DropConstraint +import org.apache.spark.sql.test.SharedSparkSession + +class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSession { + + test("Drop constraint") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1" + val parsed = parsePlan(sql) + val expected = DropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint if exists") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT IF EXISTS c1" + val parsed = parsePlan(sql) + val expected = DropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = true, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint cascade") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 CASCADE" + val parsed = parsePlan(sql) + val expected = DropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = true) + comparePlans(parsed, expected) + } + + test("Drop constraint restrict") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 RESTRICT" + val parsed = parsePlan(sql) + val expected = DropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint with invalid name") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1-c3 ENFORCE" + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near '-'")) + } + + test("Drop constraint with invalid mode") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 ENFORCE" + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near 'ENFORCE': extra input 'ENFORCE'.")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala new file mode 100644 index 0000000000000..bcbae0d72118b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition, CreateTable, ReplaceTable, UnresolvedTableSpec} +import org.apache.spark.sql.types.StringType + +class CheckConstraintParseSuite extends ConstraintParseSuiteBase { + override val validConstraintCharacteristics = + super.validConstraintCharacteristics ++ super.enforcedConstraintCharacteristics + + val constraint1 = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraint2 = CheckConstraint( + child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), + condition = "b = 'foo'", + name = "c2") + + test("Create table with one check constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" + verifyConstraints(sql, Seq(constraint1)) + } + + test("Create table with two check constraints - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + + verifyConstraints(sql, Seq(constraint1, constraint2)) + } + + test("Create table with valid characteristic - table level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr) USING parquet" + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint)) + } + } + + test("Create table with invalid characteristic - table level") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 33, + stop = 61 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(s"CREATE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Create table with one check constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet" + verifyConstraints(sql, Seq(constraint1)) + } + + test("Create table with two check constraints - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), " + + "b STRING CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + verifyConstraints(sql, Seq(constraint1, constraint2)) + } + + test("Create table with mixed column and table level check constraints") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING, " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + verifyConstraints(sql, Seq(constraint1, constraint2)) + } + + test("Create table with valid characteristic - column level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"CREATE TABLE t (a 
INT CONSTRAINT c1 CHECK (a > 0)" + + s" $enforcedStr $relyStr, b STRING) USING parquet" + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint)) + } + } + + test("Create table with invalid characteristic - column level") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val sql = s"CREATE TABLE t (a INT $constraintStr, b STRING) USING parquet" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 22, + stop = 50 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(sql) + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Create table with column 'constraint'") { + val sql = "CREATE TABLE t (constraint STRING) USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val expected = createExpectedPlan(columns, Seq.empty) + comparePlans(parsePlan(sql), expected) + } + + test("Replace table with one check constraint - table level") { + val sql = "REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" + verifyConstraints(sql, Seq(constraint1), isCreateTable = false) + } + + test("Replace table with two check constraints - table level") { + val sql = "REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + + verifyConstraints(sql, Seq(constraint1, constraint2), isCreateTable = false) + } + + test("Replace table with valid characteristic - table level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr) USING parquet" + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint), isCreateTable = false) + } + } + + test("Replace table with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 34, + stop = 62 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(s"REPLACE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Replace table with column 'constraint'") { + val sql = "REPLACE TABLE t (constraint STRING) USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val expected = createExpectedPlan(columns, Seq.empty, isCreateTable = false) + comparePlans(parsePlan(sql), expected) + } + + test("Add check constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (a > 0) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + constraint1) + comparePlans(parsed, expected) + } + + test("Add invalid check constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1-c3 CHECK (d > 0) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "INVALID_IDENTIFIER", "42602", Map("ident" -> "c1-c3")) + } + + test("Add invalid check constraint expression") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d >) + |""".stripMargin + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near ')'")) + } + + test("Add check constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + CheckConstraint( + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d > 0", + name = "c1", + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add check constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2", + start = 22, + stop = 50 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Create table with unnamed check constraint") { + Seq( + "CREATE TABLE t (a INT, b STRING, CHECK (a > 0))", + "CREATE TABLE t (a INT CHECK (a > 0), b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[CheckConstraint]) + assert(tableSpec.constraints.head.name.matches("t_chk_a0_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Replace table with unnamed check constraint") { + Seq( + "REPLACE TABLE t (a INT, b STRING, CHECK (a > 0))", + "REPLACE TABLE t (a INT CHECK (a > 0), b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: ReplaceTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[CheckConstraint]) + assert(tableSpec.constraints.head.name.matches("t_chk_a0_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected ReplaceTable, but got: $other") + } + } + } + + test("Add unnamed check constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD CHECK (d > 0) + |""".stripMargin + val plan = parsePlan(sql) + plan match { + case a: AddConstraint => + val table = a.table.asInstanceOf[UnresolvedTable] + assert(table.multipartIdentifier == Seq("a", "b", "c")) + assert(a.tableConstraint.isInstanceOf[CheckConstraint]) + assert(a.tableConstraint.name.matches("c_chk_d0_[0-9a-zA-Z]{6}")) + + 
case other => + fail(s"Expected AddConstraint, but got: $other") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala new file mode 100644 index 0000000000000..ea369489eb181 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{ConstraintCharacteristic, TableConstraint} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LogicalPlan, OptionList, ReplaceTable, UnresolvedTableSpec} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType} + +abstract class ConstraintParseSuiteBase extends AnalysisTest with SharedSparkSession { + protected def validConstraintCharacteristics = Seq( + ("", "", ConstraintCharacteristic(enforced = None, rely = None)), + ("NOT ENFORCED", "", ConstraintCharacteristic(enforced = Some(false), rely = None)), + ("", "RELY", ConstraintCharacteristic(enforced = None, rely = Some(true))), + ("", "NORELY", ConstraintCharacteristic(enforced = None, rely = Some(false))), + ("NOT ENFORCED", "RELY", + ConstraintCharacteristic(enforced = Some(false), rely = Some(true))), + ("NOT ENFORCED", "NORELY", + ConstraintCharacteristic(enforced = Some(false), rely = Some(false))) + ) + + protected def enforcedConstraintCharacteristics = Seq( + ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), + ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), + ("RELY", "ENFORCED", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("NORELY", "ENFORCED", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))) + ) + + protected val invalidConstraintCharacteristics = Seq( + ("ENFORCED", "ENFORCED"), + ("ENFORCED", "NOT ENFORCED"), + ("NOT ENFORCED", "ENFORCED"), + ("NOT ENFORCED", "NOT ENFORCED"), + ("RELY", "RELY"), + ("RELY", "NORELY"), + ("NORELY", "RELY"), + ("NORELY", "NORELY") + ) + + protected def createExpectedPlan( + columns: Seq[ColumnDefinition], + constraints: Seq[TableConstraint], + isCreateTable: Boolean = true): LogicalPlan = { + val tableId = UnresolvedIdentifier(Seq("t")) + val tableSpec = UnresolvedTableSpec( + 
Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), + None, None, None, None, false, constraints) + if (isCreateTable) { + CreateTable(tableId, columns, Seq.empty, tableSpec, false) + } else { + ReplaceTable(tableId, columns, Seq.empty, tableSpec, false) + } + } + + protected def verifyConstraints( + sql: String, + constraints: Seq[TableConstraint], + isCreateTable: Boolean = true): Unit = { + val parsed = parsePlan(sql) + val columns = Seq( + ColumnDefinition("a", IntegerType), + ColumnDefinition("b", StringType) + ) + val expected = createExpectedPlan( + columns = columns, constraints = constraints, isCreateTable = isCreateTable) + comparePlans(parsed, expected) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala new file mode 100644 index 0000000000000..1df5e7b9deb2e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.ForeignKeyConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint + +class ForeignKeyConstraintParseSuite extends ConstraintParseSuiteBase { + test("Create table with foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING," + + " FOREIGN KEY (a) REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + name = "t_parent_fk", + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT fk1 FOREIGN KEY (a)" + + " REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with foreign key - column level") { + val sql = "CREATE TABLE t (a INT REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + name = "t_parent_fk", + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT fk1 REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Add foreign key constraint") { + Seq( + ("", "orders_customers_fk"), + ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE orders ADD $constraintName FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + ForeignKeyConstraint( + name = expectedName, + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid foreign key constraint name") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk-1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add foreign key constraint with empty columns") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY () + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add foreign key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... ADD CONSTRAINT"), + ForeignKeyConstraint( + name = "fk1", + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add foreign key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $characteristic1 $characteristic2 + |""".stripMargin + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (customer_id)\nREFERENCES customers (id) " + + s"$characteristic1 $characteristic2", + start = 24, + stop = 91 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT REFERENCES parent(id) $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"REFERENCES parent(id) $characteristic", + start = 23, + stop = 44 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT, CONSTRAINT fk1 FOREIGN KEY (id)" + + s" REFERENCES parent(id) $characteristic) USING parquet" + val error = 
intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (id) REFERENCES parent(id) $characteristic", + start = 24, + stop = 77 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- alter table") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT fk1 FOREIGN KEY (id)" + + s" REFERENCES parent(id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (id) REFERENCES parent(id) $characteristic", + start = 22, + stop = 75 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala new file mode 100644 index 0000000000000..e3b3f162fcae3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.PrimaryKeyConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint + +class PrimaryKeyConstraintParseSuite extends ConstraintParseSuiteBase { + + test("Create table with primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint(columns = Seq("a"), name = "t_pk") + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT pk1 PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with composite primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a, b)) USING parquet" + val constraint = PrimaryKeyConstraint(columns = Seq("a", "b"), name = "t_pk") + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with primary key - column level") { + val sql = "CREATE TABLE t (a INT PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint(columns = Seq("a"), name = "t_pk") + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named primary key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT pk1 PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with multiple primary keys should fail") { + val expectedContext = ExpectedContext( + fragment = "a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)", + start = 16, + stop = 59 + ) + checkError( + exception = intercept[ParseException] { + parsePlan("CREATE TABLE t (a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)) USING parquet") + }, + condition = "MULTIPLE_PRIMARY_KEYS", + parameters = Map.empty[String, String], + queryContext = Array(expectedContext)) + } + + test("Add primary key constraint") { + Seq(("", "c_pk"), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName PRIMARY KEY (id, name) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = expectedName, + columns = Seq("id", "name") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid primary key constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk-1 PRIMARY KEY (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add primary key constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add primary key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = "pk1", + columns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add primary key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2", + start = 22, + stop = 54 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT PRIMARY KEY $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"PRIMARY KEY $characteristic", + start = 23, + stop = 34 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT, CONSTRAINT pk1 PRIMARY KEY (id) $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic", + start = 24, + stop = 55 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = 
Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- alter table") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic", + start = 22, + stop = 53 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala new file mode 100644 index 0000000000000..f33807e7cfa0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.UniqueConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, CreateTable, ReplaceTable, UnresolvedTableSpec} + +class UniqueConstraintParseSuite extends ConstraintParseSuiteBase { + + test("Create table with unnamed unique constraint") { + Seq( + "CREATE TABLE t (a INT, b STRING, UNIQUE (a))", + "CREATE TABLE t (a INT UNIQUE, b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Create table with composite unique constraint - table level") { + Seq( + "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b))", + "CREATE TABLE t (a INT, b STRING, UNIQUE (b, a))" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_b_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Create table with multiple unique constraints") { + Seq( + "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b))", + "CREATE TABLE t (a INT, UNIQUE (a), b STRING UNIQUE)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 2) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + assert(tableSpec.constraints.last.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.last.name.matches("t_uk_b_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Replace table with unnamed unique constraint") { + Seq( + "REPLACE TABLE t (a INT, b STRING, UNIQUE (a))", + "REPLACE TABLE t (a INT UNIQUE, b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: ReplaceTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected ReplaceTable, but got: $other") + } + } + } + + test("Add unnamed unique constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD UNIQUE (d) + |""".stripMargin + val plan = parsePlan(sql) + plan match { + case a: AddConstraint => + val table = a.table.asInstanceOf[UnresolvedTable] + assert(table.multipartIdentifier == Seq("a", "b", "c")) + assert(a.tableConstraint.isInstanceOf[UniqueConstraint]) + assert(a.tableConstraint.name.matches("c_uk_d_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected AddConstraint, but got: $other") + } + } + + test("Create table with named unique 
constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT uk1 UNIQUE (a)) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named unique constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT uk1 UNIQUE, b STRING) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Add unique constraint") { + Seq( + ("consTrainT abcdEF", "abcdEF"), + ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName UNIQUE (email, username) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + UniqueConstraint( + name = expectedName, + columns = Seq("email", "username") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid unique constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk-1 UNIQUE (email) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add unique constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add unique constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + UniqueConstraint( + name = "uk1", + columns = Seq("email"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add unique constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2", + start = 22, + stop = 52 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for unique -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT UNIQUE $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"UNIQUE $characteristic", + start = 23, + stop = 29 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for unique -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT CONSTRAINT uk1 UNIQUE $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uk1 UNIQUE $characteristic", + start = 23, + stop = 44 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for unique -- alter table") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT uni UNIQUE (id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uni UNIQUE (id) $characteristic", + start = 22, + stop = 48 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala new file mode 100644 index 0000000000000..818f4f42c65d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.constraints.Check +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + + test("Nondeterministic expression -- alter table") { + withTable("t") { + sql("create table t(i double)") + val query = + """ + |ALTER TABLE t ADD CONSTRAINT c1 CHECK (i > rand(0)) + |""".stripMargin + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + sqlState = "42621", + parameters = Map.empty, + context = ExpectedContext( + fragment = "i > rand(0)", + start = 40, + stop = 50 + ) + ) + } + } + + test("Nondeterministic expression -- create table") { + Seq( + "create table t(i double check (i > rand(0)))", + "create table t(i double, constraint c1 check (i > rand(0)))", + "replace table t(i double check (i > rand(0)))", + "replace table t(i double, constraint c1 check (i > rand(0)))" + ).foreach { query => + withTable("t") { + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + sqlState = "42621", + parameters = Map.empty, + context = ExpectedContext( + fragment = "i > rand(0)" + ) + ) + } + } + } + + test("Expression referring a column of another table -- alter table") { + withTable("t", "t2") { + sql("create table t(i double) using parquet") + sql("create table t2(j string) using parquet") + val query = + """ + |ALTER TABLE t ADD CONSTRAINT c1 CHECK (len(t2.j) > 0) + |""".stripMargin + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`t2`.`j`", "proposal" -> "`t`.`i`"), + context = ExpectedContext( + fragment = "t2.j", + start = 44, + stop = 47 + ) + ) + } + } + + test("Expression referring a column of another table -- create and replace table") { + withTable("t", "t2") { + sql("create table t(i double) using parquet") + val query = "create table t2(j string check(t.i > 0)) using parquet" + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`t`.`i`", "proposal" -> "`j`"), + context = ExpectedContext( + fragment = "t.i" + ) + ) + } + } + + private def getCheckConstraint(table: 
Table): Check = { + assert(table.constraints.length == 1) + assert(table.constraints.head.isInstanceOf[Check]) + table.constraints.head.asInstanceOf[Check] + table.constraints.head.asInstanceOf[Check] + } + + test("Predicate should be null if it can't be converted to V2 predicate") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, j string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (from_json(j, 'a INT').a > 1)") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == + "CONSTRAINT c1 CHECK from_json(j, 'a INT').a > 1 ENFORCED VALID NORELY") + assert(constraint.sql() == "from_json(j, 'a INT').a > 1") + assert(constraint.predicate() == null) + } + } + + def getConstraintCharacteristics(isCreateTable: Boolean): Seq[(String, String)] = { + val validStatus = if (isCreateTable) "UNVALIDATED" else "VALID" + Seq( + ("", s"ENFORCED $validStatus NORELY"), + ("NOT ENFORCED", s"NOT ENFORCED $validStatus NORELY"), + ("NOT ENFORCED NORELY", s"NOT ENFORCED $validStatus NORELY"), + ("NORELY NOT ENFORCED", s"NOT ENFORCED $validStatus NORELY"), + ("NORELY", s"ENFORCED $validStatus NORELY"), + ("NOT ENFORCED RELY", s"NOT ENFORCED $validStatus RELY"), + ("RELY NOT ENFORCED", s"NOT ENFORCED $validStatus RELY"), + ("NOT ENFORCED RELY", s"NOT ENFORCED $validStatus RELY"), + ("RELY NOT ENFORCED", s"NOT ENFORCED $validStatus RELY"), + ("RELY", s"ENFORCED $validStatus RELY") + ) + } + + test("Create table with check constraint") { + getConstraintCharacteristics(true).foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT c1 CHECK (id > 0) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == s"CONSTRAINT c1 CHECK id > 0 $expectedDDL") + } + } + } + + test("Alter table add check constraint") { + getConstraintCharacteristics(false).foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == s"CONSTRAINT c1 CHECK id > 0 $expectedDDL") + } + } + } + + test("Add duplicated check constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT abc CHECK (id > 0)") + // Constraint names are case-insensitive + Seq("abc", "ABC").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name CHECK (id>0)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "abc", + "oldConstraint" -> "CONSTRAINT abc CHECK id > 0 ENFORCED VALID NORELY") + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 6ba60e245f9b4..8e7a4f55fc5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryCatalog, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryCatalog, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTable, InMemoryTableCatalog} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ArrayImplicits._ @@ -64,4 +64,11 @@ trait CommandSuiteBase extends SharedSparkSession { assert(partMetadata.containsKey("location")) assert(partMetadata.get("location") === expected) } + + def loadTable(catalog: String, schema : String, table: String): InMemoryTable = { + import CatalogV2Implicits._ + val catalogPlugin = spark.sessionState.catalogManager.catalog(catalog) + catalogPlugin.asTableCatalog.loadTable(Identifier.of(Array(schema), table)) + .asInstanceOf[InMemoryTable] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala new file mode 100644 index 0000000000000..f492e18a6e529 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.catalog.constraints.Check +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class DropConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. 
DROP CONSTRAINT" + + test("Drop a non-exist constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE $t DROP CONSTRAINT c1") + } + checkError( + exception = e, + condition = "CONSTRAINT_DOES_NOT_EXIST", + sqlState = "42704", + parameters = Map("constraintName" -> "c1", "tableName" -> "test_catalog.ns.tbl") + ) + } + } + + test("Drop a non-exist constraint if exists") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t DROP CONSTRAINT IF EXISTS c1") + } + } + + test("Drop a constraint on a non-exist table") { + Seq("", "IF EXISTS").foreach { ifExists => + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE test_catalog.ns.tbl DROP CONSTRAINT $ifExists c1") + } + checkError( + exception = e, + condition = "TABLE_OR_VIEW_NOT_FOUND", + sqlState = "42P01", + parameters = Map("relationName" -> "`test_catalog`.`ns`.`tbl`"), + context = ExpectedContext( + fragment = "test_catalog.ns.tbl", + start = 12, + stop = 30 + ) + ) + } + } + + test("Drop existing constraints") { + Seq("", "IF EXISTS").foreach { ifExists => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0)") + sql(s"ALTER TABLE $t ADD CONSTRAINT c2 CHECK (len(data) > 0)") + sql(s"ALTER TABLE $t DROP CONSTRAINT $ifExists c1") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head.asInstanceOf[Check] + assert(constraint.name() == "c2") + + sql(s"ALTER TABLE $t DROP CONSTRAINT $ifExists c2") + val table2 = loadTable(catalog, "ns", "tbl") + assert(table2.constraints.length == 0) + } + } + } + + test("Drop constraint is case insensitive") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT abc CHECK (id > 0)") + sql(s"ALTER TABLE $t DROP CONSTRAINT aBC") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala new file mode 100644 index 0000000000000..14abe9d1ab9ba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class ForeignKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + + test("Add foreign key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string) $defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES ${t}_ref(id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES test_catalog.ns.tbl_ref (id) $expectedDDL") + } + } + } + + test("Create table with foreign key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + val constraintStr = s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES ${t}_ref(id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES test_catalog.ns.tbl_ref (id) $expectedDDL") + } + } + } + + test("Add duplicated foreign key constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string) $defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT fk1 FOREIGN KEY (fk) REFERENCES ${t}_ref(id)") + // Constraint names are case-insensitive + Seq("fk1", "FK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name FOREIGN KEY (fk) REFERENCES ${t}_ref(id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "fk1", + "oldConstraint" -> + ("CONSTRAINT fk1 FOREIGN KEY (fk) " + + "REFERENCES test_catalog.ns.tbl_ref (id) NOT ENFORCED UNVALIDATED NORELY")) + ) + } + } + } + + test("Add foreign key constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, fk1 bigint, fk2 bigint, data string) " + + s"$defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id1 bigint, id2 bigint, data string) $defaultUsing") + assert(loadTable(catalog, 
"ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT fk1 FOREIGN KEY (fk1, fk2) REFERENCES ${t}_ref(id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == + s"CONSTRAINT fk1 FOREIGN KEY (fk1, fk2) " + + s"REFERENCES test_catalog.ns.tbl_ref (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala new file mode 100644 index 0000000000000..ae404aff274ac --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class PrimaryKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. 
ADD CONSTRAINT" + + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + + test("Add primary key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == s"CONSTRAINT pk1 PRIMARY KEY (id) $expectedDDL") + } + } + } + + test("Create table with primary key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == s"CONSTRAINT pk1 PRIMARY KEY (id) $expectedDDL") + } + } + } + + test("Add duplicated primary key constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id)") + // Constraint names are case-insensitive + Seq("pk1", "PK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name PRIMARY KEY (id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "pk1", + "oldConstraint" -> "CONSTRAINT pk1 PRIMARY KEY (id) NOT ENFORCED UNVALIDATED NORELY") + ) + } + } + } + + test("Add primary key constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == + "CONSTRAINT pk1 PRIMARY KEY (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala new file mode 100644 index 0000000000000..4eee2c248cfde --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class UniqueConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + + test("Add unique constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == s"CONSTRAINT uk1 UNIQUE (id) $expectedDDL") + } + } + } + + test("Create table with unique constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT uk1 UNIQUE (id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == s"CONSTRAINT uk1 UNIQUE (id) $expectedDDL") + } + } + } + + test("Add duplicated unique constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id)") + // Constraint names are case-insensitive + Seq("uk1", "UK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name UNIQUE (id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "uk1", + "oldConstraint" -> "CONSTRAINT uk1 UNIQUE (id) NOT ENFORCED UNVALIDATED NORELY") + ) + } + } + } + + test("Add unique constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, data string) $defaultUsing") + 
assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == + "CONSTRAINT uk1 UNIQUE (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } +}
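
For reference, a minimal sketch of the constraint DDL these suites drive through sql(...), assuming the `test_catalog` catalog configured by `CommandSuiteBase` (the same catalog visible in the expected toDDL strings above); the table, constraint, and provider names below are illustrative rather than taken from the diff:

    // Sketch of the DDL surface covered by PrimaryKeyConstraintSuite and UniqueConstraintSuite.
    // The USING clause is a placeholder; the suites substitute it via defaultUsing.
    spark.sql("CREATE TABLE test_catalog.ns.tbl (id BIGINT, data STRING) USING _")
    spark.sql("ALTER TABLE test_catalog.ns.tbl ADD CONSTRAINT pk1 PRIMARY KEY (id)")
    spark.sql("ALTER TABLE test_catalog.ns.tbl ADD CONSTRAINT uk1 UNIQUE (id, data) RELY")
    // When no characteristic is given, the stored constraint defaults to
    // NOT ENFORCED UNVALIDATED NORELY, so toDDL renders e.g.
    //   "CONSTRAINT pk1 PRIMARY KEY (id) NOT ENFORCED UNVALIDATED NORELY".
    // Re-adding a constraint under the same (case-insensitive) name fails with
    // CONSTRAINT_ALREADY_EXISTS (SQLSTATE 42710), as the duplicated-constraint tests check.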